diff --git a/.rat-excludes b/.rat-excludes index 72771465846b..9165872b9fb2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -94,3 +94,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +.*parquet diff --git a/R/README.md b/R/README.md index 005f56da1670..d8e75ea75260 100644 --- a/R/README.md +++ b/R/README.md @@ -63,5 +63,7 @@ You can also run the unit-tests for SparkR by running (you need to install the [ The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf +./bin/spark-submit --master yarn --deploy-mode cluster (or client) examples/src/main/r/dataframe.R +OR ./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/install-dev.bat b/R/install-dev.bat index f32670b67de9..008a5c668bc4 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,8 +25,3 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ - -rem Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd %SPARK_HOME%\R\lib -%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR -popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 4949d86d20c9..d0d7201f004a 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.4.0 +Version: 1.5.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' 'mllib.R' 'serialize.R' 'sparkR.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b2d92bdf4840..3e5c89d779b7 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -84,57 +84,136 @@ exportClasses("Column") exportMethods("abs", "acos", + "add_months", "alias", "approxCountDistinct", "asc", + "ascii", "asin", "atan", "atan2", "avg", + "base64", "between", + "bin", + "bitwiseNOT", "cast", "cbrt", + "ceil", "ceiling", + "concat", + "concat_ws", "contains", + "conv", "cos", "cosh", + "count", "countDistinct", + "crc32", + "date_add", + "date_format", + "date_sub", + "datediff", + "dayofmonth", + "dayofyear", "desc", "endsWith", "exp", + "explode", "expm1", + "expr", + "factorial", + "first", "floor", + "format_number", + "format_string", + "from_unixtime", + "from_utc_timestamp", "getField", "getItem", + "greatest", + "hex", + "hour", "hypot", + "ifelse", + "initcap", + "instr", + "isNaN", "isNotNull", "isNull", "last", + "last_day", + "least", + "length", + "levenshtein", "like", + "lit", + "locate", "log", "log10", "log1p", + "log2", "lower", + "lpad", + "ltrim", "max", + "md5", "mean", "min", + "minute", + "month", + "months_between", "n", "n_distinct", + "nanvl", + "negate", + "next_day", + "otherwise", + "pmod", + "quarter", + "rand", + "randn", + "regexp_extract", + "regexp_replace", + "reverse", "rint", "rlike", + "round", + "rpad", + "rtrim", + "second", + "sha1", + "sha2", + "shiftLeft", + "shiftRight", + "shiftRightUnsigned", "sign", + "signum", "sin", "sinh", + "size", + "soundex", "sqrt", "startsWith", "substr", + "substring_index", "sum", "sumDistinct", "tan", "tanh", "toDegrees", "toRadians", - "upper") + "to_date", + "to_utc_timestamp", + "translate", + "trim", + "unbase64", + "unhex", + "unix_timestamp", + "upper", + "weekofyear", + "when", + "year") exportClasses("GroupedData") exportMethods("agg") 
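The NAMESPACE exports above are backed by the new R/pkg/R/functions.R further down and are exercised in R/pkg/inst/tests/test_sparkSQL.R. A minimal usage sketch (editorial illustration only, not part of the patch), assuming a SparkR 1.5 session in which `sqlContext` has already been created:

library(SparkR)
people <- createDataFrame(sqlContext, list(list(name = "Andy", age = 30L)))
# string and literal expressions
collect(select(people, concat(people$name, lit(":"), people$age)))    # "Andy:30"
# conditional expressions built from when()/otherwise()/ifelse()
collect(select(people, otherwise(when(people$age > 25, 1), 0)))       # 1
collect(select(people, ifelse(people$age > 25, "senior", "junior")))  # "senior"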
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index eeaf9f193b72..5a07ebd30829 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -60,12 +60,6 @@ operators <- list( ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +115,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } createMethods() @@ -243,44 +204,16 @@ setMethod("%in%", return(column(jc)) }) -#' Approx Count Distinct +#' otherwise #' -#' @rdname column -#' @return the approximate number of distinct items in a group. -setMethod("approxCountDistinct", - signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) - }) - -#' Count Distinct +#' If values in the specified column are null, returns the value. +#' Can be used in conjunction with `when` to specify a default value for expressions. #' #' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) +setMethod("otherwise", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) 
- }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 6d364f77be7e..33bf13ec9e78 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -176,10 +176,14 @@ readRow <- function(inputCon) { # Take a single column as Array[Byte] and deserialize it into an atomic vector readCol <- function(inputCon, numRows) { - # sapply can not work with POSIXlt - do.call(c, lapply(1:numRows, function(x) { - value <- readObject(inputCon) - # Replace NULL with NA so we can coerce to vectors - if (is.null(value)) NA else value - })) + if (numRows > 0) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { + value <- readObject(inputCon) + # Replace NULL with NA so we can coerce to vectors + if (is.null(value)) NA else value + })) + } else { + vector() + } } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 000000000000..b5879bd9ad55 --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,615 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +#' @include generics.R column.R +NULL + +#' @title S4 expression functions for DataFrame column(s) +#' @description These are expression functions on DataFrame columns + +functions1 <- c( + "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan", + "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count", + "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial", + "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day", + "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5", + "mean", "min", "minute", "month", "negate", "quarter", "reverse", + "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size", + "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees", + "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear", + "year") +functions2 <- c( + "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod") + +createFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createFunction2 <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + +createFunctions <- function() { + for (name in functions1) { + createFunction1(name) + } + for (name in functions2) { + createFunction2(name) + } +} + +createFunctions() + +#' @rdname functions +#' @return Creates a Column class of literal value. 
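# Editorial illustration (not part of the patch): for each name in functions1, the loop
# above generates a wrapper equivalent to the hand-written method below -- it forwards the
# Column's Java reference (x@jc) to org.apache.spark.sql.functions via callJStatic and wraps
# the returned Java column back into an R Column. The functions2 wrappers do the same for
# two-argument functions such as atan2, passing x@jc when the second argument is a Column
# and the raw value otherwise.
setMethod("ascii",
          signature(x = "Column"),
          function(x) {
            jc <- callJStatic("org.apache.spark.sql.functions", "ascii", x@jc)
            column(jc)
          })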
+setMethod("lit", signature("ANY"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lit", + ifelse(class(x) == "Column", x@jc, x)) + column(jc) + }) + +#' Approx Count Distinct +#' +#' @rdname functions +#' @return the approximate number of distinct items in a group. +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @rdname functions +#' @return the number of distinct items in a group. +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + +#' @rdname functions +#' @return Concatenates multiple input string columns together into a single string column. +setMethod("concat", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + column(jc) + }) + +#' @rdname functions +#' @return Returns the greatest value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +setMethod("greatest", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + column(jc) + }) + +#' @rdname functions +#' @return Returns the least value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null iff all parameters are null. +setMethod("least", + signature(x = "Column"), + function(x, ...) { + stopifnot(length(list(...)) > 0) + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + column(jc) + }) + +#' @rdname functions +#' @aliases ceil +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' @rdname functions +#' @aliases signum +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @rdname functions +#' @aliases countDistinct +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' @rdname functions +#' @aliases count +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) + +#' date_format +#' +#' Converts a date/timestamp/string to a value of string in the format specified by the date +#' format given by the second argument. +#' +#' A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All +#' pattern letters of `java.text.SimpleDateFormat` can be used. +#' +#' NOTE: Use when ever possible specialized functions like `year`. These benefit from a +#' specialized implementation. +#' +#' @rdname functions +setMethod("date_format", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) + column(jc) + }) + +#' from_utc_timestamp +#' +#' Assumes given timestamp is UTC and converts to given timezone. 
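# Editorial illustration (not part of the patch) of the NOTE under date_format above:
# prefer a specialized function such as year() when one exists. Assumes a SparkR session
# where sqlContext has been created; the values match the date tests in test_sparkSQL.R.
dates <- createDataFrame(sqlContext, list(list(d = as.Date("2012-12-13"))))
collect(select(dates, date_format(dates$d, "y")))  # "2012" as a string, via a SimpleDateFormat pattern
collect(select(dates, year(dates$d)))              # 2012 as a number, specialized implementation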
+#' +#' @rdname functions +setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) + column(jc) + }) + +#' instr +#' +#' Locate the position of the first occurrence of substr column in the given string. +#' Returns null if either of the arguments are null. +#' +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @rdname functions +setMethod("instr", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) + column(jc) + }) + +#' next_day +#' +#' Given a date column, returns the first date which is later than the value of the date column +#' that is on the specified day of the week. +#' +#' For example, `next <- day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first +#' Sunday after 2015-07-27. +#' +#' Day of the week parameter is case insensitive, and accepts: +#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". +#' +#' @rdname functions +setMethod("next_day", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) + column(jc) + }) + +#' to_utc_timestamp +#' +#' Assumes given timestamp is in given timezone and converts to UTC. +#' +#' @rdname functions +setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) + column(jc) + }) + +#' add_months +#' +#' Returns the date that is numMonths after startDate. +#' +#' @rdname functions +setMethod("add_months", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) + column(jc) + }) + +#' date_add +#' +#' Returns the date that is `days` days after `start` +#' +#' @rdname functions +setMethod("date_add", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) + column(jc) + }) + +#' date_sub +#' +#' Returns the date that is `days` days before `start` +#' +#' @rdname functions +setMethod("date_sub", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) + column(jc) + }) + +#' format_number +#' +#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' and returns the result as a string column. +#' +#' If d is 0, the result has no decimal point or fractional part. +#' If d < 0, the result will be null.' +#' +#' @rdname functions +setMethod("format_number", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "format_number", + y@jc, as.integer(x)) + column(jc) + }) + +#' sha2 +#' +#' Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. +#' +#' @rdname functions +#' @param y column to compute SHA-2 on. +#' @param x one of 224, 256, 384, or 512. +setMethod("sha2", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) + column(jc) + }) + +#' shiftLeft +#' +#' Shift the the given value numBits left. 
If the given value is a long value, this function +#' will return a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftLeft", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftLeft", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRight +#' +#' Shift the the given value numBits right. If the given value is a long value, it will return +#' a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftRight", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRight", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRightUnsigned +#' +#' Unsigned shift the the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRightUnsigned", + y@jc, as.integer(x)) + column(jc) + }) + +#' concat_ws +#' +#' Concatenates multiple input string columns together into a single string column, +#' using the given separator. +#' +#' @rdname functions +setMethod("concat_ws", signature(sep = "character", x = "Column"), + function(sep, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) + column(jc) + }) + +#' conv +#' +#' Convert a number in a string column from one base to another. +#' +#' @rdname functions +setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), + function(x, fromBase, toBase) { + fromBase <- as.integer(fromBase) + toBase <- as.integer(toBase) + jc <- callJStatic("org.apache.spark.sql.functions", + "conv", + x@jc, fromBase, toBase) + column(jc) + }) + +#' expr +#' +#' Parses the expression string into the column that it represents, similar to +#' DataFrame.selectExpr +#' +#' @rdname functions +setMethod("expr", signature(x = "character"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) + column(jc) + }) + +#' format_string +#' +#' Formats the arguments in printf-style and returns the result as a string column. +#' +#' @rdname functions +setMethod("format_string", signature(format = "character", x = "Column"), + function(format, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", + "format_string", + format, jcols) + column(jc) + }) + +#' from_unixtime +#' +#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string +#' representing the timestamp of that moment in the current system time zone in the given +#' format. +#' +#' @rdname functions +setMethod("from_unixtime", signature(x = "Column"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", + "from_unixtime", + x@jc, format) + column(jc) + }) + +#' locate +#' +#' Locate the position of the first occurrence of substr. +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. 
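# Editorial illustration (not part of the patch) of the 1-based semantics described above;
# the values mirror the string-operator tests in test_sparkSQL.R. Assumes an active sqlContext.
strs <- createDataFrame(sqlContext, list(list(s = "aaads")))
collect(select(strs, instr(strs$s, "ad")))       # 3  -- 1-based position of the first match
collect(select(strs, locate("aa", strs$s)))      # 1
collect(select(strs, locate("aa", strs$s, 1)))   # 2  -- search resumes after position 1
collect(select(strs, locate("zz", strs$s)))      # 0  -- substring not found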
+#' +#' @rdname functions +setMethod("locate", signature(substr = "character", str = "Column"), + function(substr, str, pos = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", + "locate", + substr, str@jc, as.integer(pos)) + column(jc) + }) + +#' lpad +#' +#' Left-pad the string column with +#' +#' @rdname functions +setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' rand +#' +#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' +#' @rdname functions +setMethod("rand", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand") + column(jc) + }) +setMethod("rand", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) + column(jc) + }) + +#' randn +#' +#' Generate a column with i.i.d. samples from the standard normal distribution. +#' +#' @rdname functions +setMethod("randn", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn") + column(jc) + }) +setMethod("randn", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) + column(jc) + }) + +#' regexp_extract +#' +#' Extract a specific(idx) group identified by a java regex, from the specified string column. +#' +#' @rdname functions +setMethod("regexp_extract", + signature(x = "Column", pattern = "character", idx = "numeric"), + function(x, pattern, idx) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_extract", + x@jc, pattern, as.integer(idx)) + column(jc) + }) + +#' regexp_replace +#' +#' Replace all substrings of the specified string value that match regexp with rep. +#' +#' @rdname functions +setMethod("regexp_replace", + signature(x = "Column", pattern = "character", replacement = "character"), + function(x, pattern, replacement) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_replace", + x@jc, pattern, replacement) + column(jc) + }) + +#' rpad +#' +#' Right-padded with pad to a length of len. +#' +#' @rdname functions +setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "rpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' substring_index +#' +#' Returns the substring from string str before count occurrences of the delimiter delim. +#' If count is positive, everything the left of the final delimiter (counting from left) is +#' returned. If count is negative, every to the right of the final delimiter (counting from the +#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' +#' @rdname functions +setMethod("substring_index", + signature(x = "Column", delim = "character", count = "numeric"), + function(x, delim, count) { + jc <- callJStatic("org.apache.spark.sql.functions", + "substring_index", + x@jc, delim, as.integer(count)) + column(jc) + }) + +#' translate +#' +#' Translate any character in the src by a character in replaceString. +#' The characters in replaceString is corresponding to the characters in matchingString. +#' The translate will happen when any character in the string matching with the character +#' in the matchingString. 
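# Editorial illustration (not part of the patch) of substring_index() and translate();
# the expected values mirror the assertions added to test_sparkSQL.R. Assumes an active sqlContext.
paths <- createDataFrame(sqlContext, list(list(a = "a.b.c.d")))
collect(select(paths, substring_index(paths$a, ".", 2)))   # "a.b"     -- left of the 2nd delimiter
collect(select(paths, substring_index(paths$a, ".", -3)))  # "b.c.d"   -- right of the 3rd delimiter from the end
collect(select(paths, translate(paths$a, "bc", "12")))     # "a.1.2.d" -- b -> 1, c -> 2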
+#' +#' @rdname functions +setMethod("translate", + signature(x = "Column", matchingString = "character", replaceString = "character"), + function(x, matchingString, replaceString) { + jc <- callJStatic("org.apache.spark.sql.functions", + "translate", x@jc, matchingString, replaceString) + column(jc) + }) + +#' unix_timestamp +#' +#' Gets current Unix timestamp in seconds. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "missing", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") + column(jc) + }) +#' unix_timestamp +#' +#' Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), +#' using the default timezone and the default locale, return null if fail. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) + column(jc) + }) +#' unix_timestamp +#' +#' Convert time string with given pattern +#' (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) +#' to Unix time stamp (in seconds), return null if fail. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "Column", format = "character"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) + column(jc) + }) +#' when +#' +#' Evaluates a list of conditions and returns one of multiple possible result expressions. +#' For unmatched expressions null is returned. +#' +#' @rdname column +setMethod("when", signature(condition = "Column", value = "ANY"), + function(condition, value) { + condition <- condition@jc + value <- ifelse(class(value) == "Column", value@jc, value) + jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) + column(jc) + }) + +#' ifelse +#' +#' Evaluates a list of conditions and returns `yes` if the conditions are satisfied. +#' Otherwise `no` is returned for unmatched conditions. +#' +#' @rdname column +setMethod("ifelse", + signature(test = "Column", yes = "ANY", no = "ANY"), + function(test, yes, no) { + test <- test@jc + yes <- ifelse(class(yes) == "Column", yes@jc, yes) + no <- ifelse(class(no) == "Column", no@jc, no) + jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", + "when", + test, yes), + "otherwise", no) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c43b947129e8..84cb8dfdaa2d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -535,8 +535,8 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) -##' rdname summary -##' @export +#' @rdname summary +#' @export setGeneric("summary", function(x, ...) { standardGeneric("summary") }) # @rdname tojson @@ -575,10 +575,6 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column -#' @export -setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) - #' @rdname column #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) @@ -587,13 +583,10 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) + #' @rdname column #' @export setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) @@ -660,20 +653,325 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) #' @rdname column #' @export -setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname column #' @export +setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) + + +###################### Expression Function Methods ########################## + +#' @rdname functions +#' @export +setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) + +#' @rdname functions +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname functions +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname functions +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname functions +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname functions +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname functions +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname functions +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname functions +#' @export +setGeneric("concat", function(x, ...) { standardGeneric("concat") }) + +#' @rdname functions +#' @export +setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) + +#' @rdname functions +#' @export +setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) + +#' @rdname functions +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname functions +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname functions +#' @export +setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) + +#' @rdname functions +#' @export +setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) + +#' @rdname functions +#' @export +setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) + +#' @rdname functions +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname functions +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname functions +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname functions +#' @export +setGeneric("expr", function(x) { standardGeneric("expr") }) + +#' @rdname functions +#' @export +setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) + +#' @rdname functions +#' @export +setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) + +#' @rdname functions +#' @export +setGeneric("format_string", function(format, x, ...) 
{ standardGeneric("format_string") }) + +#' @rdname functions +#' @export +setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) + +#' @rdname functions +#' @export +setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) + +#' @rdname functions +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname functions +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname functions +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname functions +#' @export +setGeneric("instr", function(y, x) { standardGeneric("instr") }) + +#' @rdname functions +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname functions +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname functions +#' @export +setGeneric("least", function(x, ...) { standardGeneric("least") }) + +#' @rdname functions +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname functions +#' @export +setGeneric("lit", function(x) { standardGeneric("lit") }) + +#' @rdname functions +#' @export +setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) + +#' @rdname functions +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname functions +#' @export +setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) + +#' @rdname functions +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname functions +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname functions +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname functions +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname functions +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname functions +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname functions +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + +#' @rdname functions +#' @export +setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) + +#' @rdname functions +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname functions +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname functions +#' @export +setGeneric("rand", function(seed) { standardGeneric("rand") }) + +#' @rdname functions +#' @export +setGeneric("randn", function(seed) { standardGeneric("randn") }) + +#' @rdname functions +#' @export +setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) + +#' @rdname functions +#' @export +setGeneric("regexp_replace", + function(x, pattern, replacement) { standardGeneric("regexp_replace") }) + +#' @rdname functions +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname functions +#' @export +setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) + +#' @rdname functions +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname functions +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname functions +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname functions +#' @export 
+setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) + +#' @rdname functions +#' @export +setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) + +#' @rdname functions +#' @export +setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) + +#' @rdname functions +#' @export +setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) + +#' @rdname functions +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname functions +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname functions +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname functions +#' @export +setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) + +#' @rdname functions +#' @export +setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) + +#' @rdname functions +#' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname functions +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname functions +#' @export +setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) + +#' @rdname functions +#' @export +setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) + +#' @rdname functions +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname functions +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname functions +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname functions +#' @export +setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) + +#' @rdname functions #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname functions +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname functions +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + #' @rdname glm #' @export setGeneric("glm") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd8749..cea3d760d05f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -56,10 +56,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' #' Makes predictions from a model produced by glm(), similarly to R's predict(). #' -#' @param model A fitted MLlib model +#' @param object A fitted MLlib model #' @param newData DataFrame for testing #' @return DataFrame containing predicted values -#' @rdname glm +#' @rdname predict #' @export #' @examples #'\dontrun{ @@ -76,10 +76,10 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param model A fitted MLlib model +#' @param x A fitted MLlib model #' @return a list with a 'coefficient' component, which is the matrix of coefficients. See #' summary.glm for more information. 
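# Editorial illustration (not part of the patch) of the predict()/summary() documentation
# split above. The data, formula and column names are hypothetical; assumes an active sqlContext.
training <- createDataFrame(sqlContext, data.frame(y = c(1, 2, 3, 4), x = c(2, 4, 6, 8)))
model <- glm(y ~ x, data = training, family = "gaussian")
summary(model)                     # now documented under ?summary: coefficient matrix of the fit
preds <- predict(model, training)  # now documented under ?predict: returns a DataFrame
collect(preds)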
-#' @rdname glm +#' @rdname summary #' @export #' @examples #'\dontrun{ diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 7377fc8f1ca9..556b8c544705 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -408,6 +408,14 @@ test_that("collect() returns a data.frame", { expect_equal(names(rdf)[1], "age") expect_equal(nrow(rdf), 3) expect_equal(ncol(rdf), 2) + + # collect() returns data correctly from a DataFrame with 0 row + df0 <- limit(df, 0) + rdf <- collect(df0) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 0) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { @@ -492,6 +500,18 @@ test_that("head() and first() return the correct data", { testFirst <- first(df) expect_equal(nrow(testFirst), 1) + + # head() and first() return the correct data on + # a DataFrame with 0 row + df0 <- limit(df, 0) + + testHead <- head(df0) + expect_equal(nrow(testHead), 0) + expect_equal(ncol(testHead), 2) + + testFirst <- first(df0) + expect_equal(nrow(testFirst), 0) + expect_equal(ncol(testFirst), 2) }) test_that("distinct() and unique on DataFrames", { @@ -560,6 +580,11 @@ test_that("select with column", { df2 <- select(df, df$age) expect_equal(columns(df2), c("age")) expect_equal(count(df2), 3) + + df3 <- select(df, lit("x")) + expect_equal(columns(df3), c("x")) + expect_equal(count(df3), 3) + expect_equal(collect(select(df3, "x"))[[1, 1]], "x") }) test_that("selectExpr() on a DataFrame", { @@ -573,6 +598,11 @@ test_that("selectExpr() on a DataFrame", { expect_equal(count(selected2), 3) }) +test_that("expr() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) +}) + test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) @@ -640,15 +670,17 @@ test_that("column operators", { test_that("column functions", { c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) @@ -661,8 +693,11 @@ test_that("column functions", { expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) 
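# Editorial illustration (not part of the patch): expr() parses a SQL expression string into
# a Column, much as selectExpr() does for a whole projection. Assumes an active sqlContext.
edf <- createDataFrame(sqlContext, list(list(a = 1L)))
collect(select(edf, expr("abs(-123)")))   # 123
collect(selectExpr(edf, "a + 1"))         # 2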
expect_equal(collect(df3)[[3, 1]], TRUE) -}) + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") +}) +# test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", @@ -681,6 +716,13 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) }) test_that("string operators", { @@ -689,6 +731,94 @@ test_that("string operators", { expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") + expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + 
expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) +}) + +test_that("greatest() and least() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) + expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) +}) + +test_that("when(), otherwise() and ifelse() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) test_that("group by", { diff --git a/README.md b/README.md index 380422ca00db..2d2d1e2e6b59 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,8 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +examples to a cluster. 
This can be a mesos:// or spark:// URL, to run on YARN; either --master yarn and set --deploy-mode (cluster or client) or simply set --master as +"yarn-cluster" or "yarn-client", and "local" to run locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 27006e45e932..74c5cea94403 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -10,6 +10,8 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO +log4j.logger.org.apache.parquet=ERROR +log4j.logger.parquet=ERROR # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL diff --git a/core/pom.xml b/core/pom.xml index 0e53a79fd223..4f79d71bf85f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -266,7 +266,7 @@ org.tachyonproject tachyon-client - 0.7.0 + 0.7.1 org.apache.hadoop diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 0399abc63c23..0e58bb4f7101 100644 --- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -25,7 +25,7 @@ import scala.reflect.ClassTag; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. 
@@ -49,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -64,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 925b60a14588..3d1ef0c48adc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -37,7 +37,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -211,16 +211,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -447,14 +443,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 02084f9122e0..2389c28b2839 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.TaskMemoryManager; @Private @@ -244,7 +244,7 @@ void 
insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7f79cd13aab4..b24eed3952fd 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,25 +17,24 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; import java.util.Iterator; import java.util.LinkedList; import java.util.List; -import javax.annotation.Nullable; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -93,9 +92,9 @@ public final class BytesToBytesMap { /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be - * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since - * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array - * entries per key, giving us a maximum capacity of (1 << 29). + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). */ @VisibleForTesting static final int MAX_CAPACITY = (1 << 29); @@ -193,6 +192,11 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } public BytesToBytesMap( @@ -270,10 +274,10 @@ public boolean hasNext() { @Override public Location next() { - int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); } loc.with(currentPage, offsetInPage); offsetInPage += 4 + totalLength; @@ -323,6 +327,20 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. 
+ * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { assert(bitset != null); assert(longArray != null); @@ -338,7 +356,8 @@ public Location lookup( } if (!bitset.isSet(pos)) { // This is a new key. - return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { @@ -356,7 +375,7 @@ public Location lookup( keyRowLengthBytes ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; @@ -402,9 +421,9 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object page, final long offsetInPage) { long position = offsetInPage; - final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + final int totalLength = Platform.getInt(page, position); position += 4; - keyLength = PlatformDependent.UNSAFE.getInt(page, position); + keyLength = Platform.getInt(page, position); position += 4; valueLength = totalLength - keyLength - 4; @@ -572,18 +591,11 @@ public boolean putNewKey( // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + if (!acquireNewPage()) { return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; dataPage = currentDataPage; dataPageBaseObject = currentDataPage.getBaseObject(); dataPageInsertOffset = currentDataPage.getBaseOffset(); @@ -608,21 +620,21 @@ public boolean putNewKey( final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, + Platform.putInt(dataPageBaseObject, recordOffset, keyLengthBytes + valueLengthBytes + 4); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.copyMemory( + Platform.copyMemory( keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, valueDataOffsetInPage, valueLengthBytes); // --- Update bookeeping data structures ----------------------------------------------------- if (useOverflowPage) { // Store the end-of-page marker at the end of the data page - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); } else { pageCursor += requiredSize; } @@ -642,6 +654,24 @@ public boolean putNewKey( } } + /** + * Acquire a new page from the {@link 
ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. @@ -748,7 +778,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5e002ae1b756..71b76d5ddfaa 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -20,10 +20,9 @@ import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.Utils; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; @Private public class PrefixComparators { @@ -73,7 +72,7 @@ public static long computePrefix(byte[] bytes) { final int minLen = Math.min(bytes.length, 8); long p = 0; for (int i = 0; i < minLen; ++i) { - p |= (128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + i)) + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) << (56 - 8 * i); } return p; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5ebbf9b068fd..fc364e0a895b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -132,16 +132,15 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); } else { this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; } - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. 
Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). - acquireNewPage(); - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). @@ -427,14 +426,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -493,18 +488,16 @@ public void insertKVRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); dataPagePosition += 4; - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += 4; - PlatformDependent.copyMemory( - keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += keyLen; - PlatformDependent.copyMemory( - valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1e4b8a116e11..f7787e1019c2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,7 +19,7 @@ import java.util.Comparator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -164,7 +164,7 @@ public void loadNext() { final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length - recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + recordLength = Platform.getInt(baseObject, baseOffset - 4); keyPrefix = sortBuffer[position + 1]; position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index ca1ccedc93c8..4989b05d63e2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description @@ -42,7 +42,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; - private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( BlockManager blockManager, diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 44cf6c756d7c..e59a84ff8d11 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -28,7 +28,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Spills a list of sorted records to disk. Spill files have the following format: @@ -117,11 +117,11 @@ public void write( long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, recordReadPosition, writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), toTransfer); writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4a893bc0189a..83dbea40b63f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -109,13 +109,13 @@ function toggleDagViz(forJob) { } $(function (){ - if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(false), "false"); toggleDagViz(false); - } - - if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(true), "false"); toggleDagViz(true); diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 064246dfa7fc..c39c8667d013 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -382,14 +382,18 @@ private[spark] object 
InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(): Seq[Accumulator[Long]] = { - Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1877aaf2cac5..b93536e6536e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -599,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -618,6 +612,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -628,6 +624,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. 
To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ff154fb5e33..b344b5e173d6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -389,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -476,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ced44131b0d..1ddaca8a5ba8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -118,9 +118,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") this.preferredNodeLocationData = preferredNodeLocationData } @@ -153,6 +155,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) + if (preferredNodeLocationData.nonEmpty) { + logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") + } this.preferredNodeLocationData = preferredNodeLocationData } @@ -528,7 +533,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
- val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) @@ -559,7 +568,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. - _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -866,7 +876,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ @@ -1667,7 +1677,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a796e7285019..0f1e2e069568 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -331,6 +331,8 @@ object SparkEnv extends Logging { case "netty" => new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => + logWarning("NIO-based block transfer service is deprecated, " + + "and will be removed in Spark 1.6.0.") new NioBlockTransferService(conf, securityManager) } diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 2ebd7a7151a5..977a27bdfe1b 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -30,3 +30,10 @@ class SparkException(message: String, cause: Throwable) */ private[spark] class SparkDriverExecutionException(cause: Throwable) extends SparkException("Execution error", cause) + +/** + * Exception thrown when the main user code is run as a child process (e.g. pyspark) and we want + * the parent SparkSubmit process to exit with the same exit code. 
+ */ +private[spark] case class SparkUserAppException(exitCode: Int) + extends SparkException(s"User application exited with $exitCode") diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d5..934d00dc708b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -90,6 +92,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is the one that caused the last failure. */ @DeveloperApi case class ExceptionFailure( @@ -97,11 +103,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +148,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException.
+ */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 000000000000..af483e361e33 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation +import scala.annotation.meta._ + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +@param @field @getter @setter @beanGetter @beanSetter +private[spark] class Since(version: String) extends StaticAnnotation diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 829fae1d1d9b..c582488f16fe 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -354,7 +354,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an array that contains all of the elements in this RDD. 
* @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead */ - @Deprecated + @deprecated("use collect()", "1.0.0") def toArray(): JList[T] = collect() /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 14dac4ed28ce..6ce02e2ea336 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -182,6 +182,7 @@ private[r] class RBackendHandler(server: RBackend) if (parameterType.isPrimitive) { parameterWrapperType = parameterType match { case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] case java.lang.Double.TYPE => classOf[java.lang.Double] case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] case _ => parameterType diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index d5b4260bf452..3c89f2447374 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -181,6 +181,7 @@ private[spark] object SerDe { // Boolean -> logical // Float -> double // Double -> double + // Decimal -> double // Long -> double // Array[Byte] -> raw // Date -> Date @@ -219,6 +220,10 @@ private[spark] object SerDe { case "float" | "java.lang.Float" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "decimal" | "java.math.BigDecimal" => + writeType(dos, "double") + val javaDecimal = value.asInstanceOf[java.math.BigDecimal] + writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 20a9faa1784b..22ef701d833b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -53,7 +53,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana /** Create a new shuffle block handler. Factored out for subclasses to override. */ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { - new ExternalShuffleBlockHandler(conf) + new ExternalShuffleBlockHandler(conf, null) } /** Starts the external shuffle service if the user has configured us to. 
*/ diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index c2ed43a5397d..23d01e9cbb9f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.util.Try +import org.apache.spark.SparkUserAppException import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -46,7 +47,20 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such val gatewayServer = new py4j.GatewayServer(null, 0) - gatewayServer.start() + val thread = new Thread(new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions { + gatewayServer.start() + } + }) + thread.setName("py4j-gateway-init") + thread.setDaemon(true) + thread.start() + + // Wait until the gateway server has started, so that we know which port it is bound to. + // `gatewayServer.start()` will start a new thread and run the server code there, after + // initializing the socket, so the thread started above will end as soon as the server is + // ready to serve connections. + thread.join() // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument @@ -64,11 +78,18 @@ object PythonRunner { env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize - val process = builder.start() + try { + val process = builder.start() - new RedirectThread(process.getInputStream, System.out, "redirect output").start() + new RedirectThread(process.getInputStream, System.out, "redirect output").start() - System.exit(process.waitFor()) + val exitCode = process.waitFor() + if (exitCode != 0) { + throw new SparkUserAppException(exitCode) + } + } finally { + gatewayServer.shutdown() + } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7e9dba42bebd..dda4216c7efe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -76,7 +76,7 @@ class SparkHadoopUtil extends Logging { } } - @Deprecated + @deprecated("use newConfiguration with SparkConf argument", "1.2.0") def newConfiguration(): Configuration = newConfiguration(null) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7ac6cbce4cd1..86fcf942c2c4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -39,8 +39,8 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.spark.{SparkUserAppException, SPARK_VERSION} import org.apache.spark.api.r.RUtils -import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import
org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -422,7 +422,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -433,7 +434,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), @@ -672,7 +672,13 @@ object SparkSubmit { mainMethod.invoke(null, childArgs.toArray) } catch { case t: Throwable => - throw findCause(t) + findCause(t) match { + case SparkUserAppException(exitCode) => + System.exit(exitCode) + + case t: Throwable => + throw t + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e3060ac3fa1a..e573ff16c50a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -126,11 +126,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. - pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) + pool.scheduleWithFixedDelay(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } } @@ -204,11 +204,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) mod1 >= mod2 } - logInfos.sliding(20, 20).foreach { batch => - replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(batch) - }) - } + logInfos.grouped(20) + .map { batch => + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(batch) + }) + } + .foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } + } lastModifiedTime = newLastModifiedTime } catch { @@ -272,9 +286,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a076a9c3f984..d4f327cc588f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -238,7 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9217202b69a6..26904d39a9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -127,14 +127,8 @@ private[deploy] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServer = - if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) - } else { - None - } - private val restServerBoundPort = restServer.map(_.start()) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) @@ -148,6 +142,12 @@ private[deploy] class Master( } }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) + masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() applicationMetricsSystem.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 061857476a8a..12337a940a41 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -34,7 +34,7 @@ import org.apache.spark.network.util.TransportConf * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. */ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) - extends ExternalShuffleBlockHandler(transportConf) with Logging { + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { // Stores a map of driver socket addresses to app ids private val connectedApps = new mutable.HashMap[SocketAddress, String] diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 29a504228557..ab3fea475c2a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5d78a9dc8885..42a85e42ea2b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer @@ -305,8 +305,16 @@ private[spark] class Executor( m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
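The Executor hunk above pairs with the new ThrowableSerializationWrapper: the task failure is first serialized with the original exception attached, and retried without it if the user exception turns out not to be serializable. The following is a minimal, self-contained Scala sketch of that fallback pattern using plain JDK serialization; SerializableFailure and describeFailure are hypothetical names used only for illustration, not part of this patch.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for ExceptionFailure: class name, message, and an optional cause.
case class SerializableFailure(
    className: String,
    message: String,
    cause: Option[Throwable])

object FailureReporting {
  private def serialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    bytes.toByteArray
  }

  // Keep the original exception when it serializes cleanly; otherwise fall back to a
  // description-only failure, mirroring the try/catch around ser.serialize(...) above.
  def describeFailure(t: Throwable): Array[Byte] = {
    val full = SerializableFailure(t.getClass.getName, t.getMessage, Some(t))
    try {
      serialize(full)
    } catch {
      case _: NotSerializableException => serialize(full.copy(cause = None))
    }
  }
}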
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f1c17369cb48..e1f8719eead0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -274,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f83a051f5da1..6a9c004d65cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -186,7 +186,7 @@ class NewHadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 6a95e44c57fe..fa3fecc80cb6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} private[spark] class SqlNewHadoopPartition( @@ -212,7 +212,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } } catch { case e: Exception => - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index bb489c6b6e98..684db6646765 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -111,7 +111,7 @@ class DAGScheduler( * * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). 
*/ - private val cacheLocs = new HashMap[Int, Seq[Seq[TaskLocation]]] + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with // every task. When we detect a node failing, we note the current epoch number and failed @@ -200,17 +200,17 @@ class DAGScheduler( // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] - def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { // Note: if the storage level is NONE, we don't need to get locations from block manager. - val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { - Seq.fill(rdd.partitions.size)(Nil) + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) } else { val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] @@ -302,12 +302,12 @@ class DAGScheduler( shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd - val numTasks = rdd.partitions.size + val numTasks = rdd.partitions.length val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.size) { + for (i <- 0 until locs.length) { stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } stage.numAvailableOutputs = locs.count(_ != null) @@ -315,7 +315,7 @@ class DAGScheduler( // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) } stage } @@ -566,7 +566,7 @@ class DAGScheduler( properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray + val partitions = (0 until rdd.partitions.length).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) @@ -677,8 +677,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = 
{ + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -715,7 +718,7 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite) + finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) @@ -762,7 +765,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -787,9 +790,10 @@ class DAGScheduler( } } + // Create internal accumulators if the stage has no accumulators initialized. // Reset internal accumulators only if this stage is not partially submitted // Otherwise, we may override existing accumulator values from some tasks - if (allPartitions == partitionsToCompute) { + if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { stage.resetInternalAccumulators() } @@ -816,7 +820,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -845,13 +849,13 @@ class DAGScheduler( } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -878,7 +882,7 @@ class DAGScheduler( } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -1035,7 +1039,7 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head), changeEpoch = true) clearCacheLocs() @@ -1098,7 +1102,8 @@ class DAGScheduler( } if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. 
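The DAGScheduler hunks above thread the original task exception through abortStage and into the SparkException that fails the job, so callers can inspect the root cause rather than only a formatted failure string. Below is a small sketch of the same idea under stated assumptions; JobAbortedException and abortJob are illustrative names only, not Spark APIs.

object JobFailure {
  class JobAbortedException(message: String, cause: Throwable)
    extends RuntimeException(message, cause)

  // exception.orNull mirrors `new SparkException(failureReason, exception.getOrElse(null))`:
  // when no cause is available, only the message is reported.
  def abortJob(reason: String, exception: Option[Throwable] = None): Nothing =
    throw new JobAbortedException(s"Job aborted due to stage failure: $reason", exception.orNull)
}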
@@ -1126,7 +1131,7 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1164,7 +1169,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { @@ -1235,7 +1240,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return @@ -1244,7 +1252,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1252,8 +1260,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1372,33 +1383,36 @@ class DAGScheduler( return rddPrefs.map(TaskLocation(_)) } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. rdd.dependencies.foreach { case n: NarrowDependency[_] => - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. 
for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } - case s: ShuffleDependency[_, _, _] => - // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION - // of data as preferred locations - if (shuffleLocalityEnabled && - rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && - s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { - // Get the preferred map output locations for this reducer - val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, - partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) - if (topLocsForReducer.nonEmpty) { - return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) - } - } - case _ => } + + // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that + // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations + if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) { + rdd.dependencies.foreach { + case s: ShuffleDependency[_, _, _] => + if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => + } + } Nil } @@ -1462,8 +1476,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a213d419cf03..f72a52e85dc1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 66c75f325fcd..48d8d8e9c4b7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -37,7 +37,7 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id - var numAvailableOutputs: Long = 0 + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index de05ee256dbf..1cf06856ffbc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -81,7 +81,7 @@ private[spark] abstract class Stage( * accumulators here again will override partial values from the finished tasks. */ def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create() + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5..818b95d67f6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -662,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,12 +707,15 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
@@ -728,16 +732,16 @@ private[spark] class TaskSetManager( logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6acf8a9a5e9b..5730a87f960a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -422,16 +422,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logWarning(s"Executor to kill $id does not exist!") } + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. 
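The comment above (SPARK-9795) is the core of the `killExecutors` change; the bookkeeping reduces to the sketch below, using simplified collections and made-up executor IDs:

```scala
// Simplified sketch of the dedup logic in CoarseGrainedSchedulerBackend.killExecutors:
// executors already pending removal are not killed again, and the new executor target
// excludes everything pending removal.
import scala.collection.mutable

val knownExecutors = Seq("1", "2", "3")
val executorsPendingToRemove = mutable.HashSet("2")   // "2" was already scheduled for removal

val executorsToKill = knownExecutors.filter(id => !executorsPendingToRemove.contains(id))
executorsPendingToRemove ++= executorsToKill          // now tracks "1", "2", "3"

assert(executorsToKill == Seq("1", "3"))              // "2" is not killed a second time
```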
if (!replace) { - doRequestTotalExecutors(numExistingExecutors + numPendingExecutors - - executorsPendingToRemove.size - knownExecutors.size) + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) } - executorsPendingToRemove ++= knownExecutors - doKillExecutors(knownExecutors) + doKillExecutors(executorsToKill) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 15a0915708c7..d6e1e9e5bebc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -194,6 +194,11 @@ private[spark] class CoarseMesosSchedulerBackend( s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index f078547e7135..1206f184fbc8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -403,6 +403,9 @@ private[spark] class MesosClusterScheduler( } builder.setValue(s"$executable $cmdOptions $jar $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } builder.build() } @@ -504,14 +507,16 @@ private[spark] class MesosClusterScheduler( val driversToRetry = pendingRetryDrivers.filter { d => d.retryState.get.nextRetry.before(currentTime) } + scheduleTasks( - driversToRetry, + copyBuffer(driversToRetry), removeFromPendingRetryDrivers, currentOffers, tasks) + // Then we walk through the queued drivers and try to schedule them. 
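The `spark.mesos.uris` handling added above (together with the `setupUris` helper introduced later in this patch) lets a job ship extra artifacts that Mesos fetches into the sandbox. A hedged usage sketch with made-up URIs:

```scala
// Assumed example values; the property is a comma-separated list that setupUris splits,
// trims, and turns into CommandInfo URIs on the Mesos command builder.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.uris", "hdfs://nn:8020/apps/extra.jar, http://repo.example.com/site.conf")

val uris = conf.get("spark.mesos.uris").split(",").map(_.trim())
// uris: Array("hdfs://nn:8020/apps/extra.jar", "http://repo.example.com/site.conf")
```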
scheduleTasks( - queuedDrivers, + copyBuffer(queuedDrivers), removeFromQueuedDrivers, currentOffers, tasks) @@ -524,13 +529,14 @@ private[spark] class MesosClusterScheduler( .foreach(o => driver.declineOffer(o.getId)) } + private def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + def getSchedulerState(): MesosClusterSchedulerState = { - def copyBuffer( - buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { - val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) - buffer.copyToBuffer(newBuffer) - newBuffer - } stateLock.synchronized { new MesosClusterSchedulerState( frameworkId, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3f63ec1c5832..5c20606d5871 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -133,6 +133,11 @@ private[spark] class MesosSchedulerBackend( builder.addAllResources(usedCpuResources) builder.addAllResources(usedMemResources) + + sc.conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index c04920e4f587..5b854aa5c275 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -331,4 +331,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.executorMemory } + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fae69551e733..d0163d326dba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -71,7 +71,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB /** * Write an index file with the offsets of each block, plus a final offset at the end for the - * end of the output file. This will be used by getBlockLocation to figure out where each block + * end of the output file. This will be used by getBlockData to figure out where each block * begins and ends. 
* */ def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 8c3a72644c38..a0d8abc2eecb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -175,7 +175,9 @@ private[spark] object ShuffleMemoryManager { val minPageSize = 1L * 1024 * 1024 // 1MB val maxPageSize = 64L * minPageSize // 64MB val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() - val safetyFactor = 8 + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + // TODO(davies): don't round to next power of 2 val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 86493673d958..fefaef0ab82c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -93,8 +93,17 @@ private[spark] class BlockManager( // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = - Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + private val externalShuffleServicePort = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // only if the yarn config has the port set to 0, we prefer the value in the spark config + conf.get("spark.shuffle.service.port").toInt + } else { + tmpPort + } + } // Check that we're not using external shuffle service with consolidated shuffle files. 
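The port logic just added to `BlockManager` is easy to misread, so a restatement: YARN test setups may put `spark.shuffle.service.port=0` in the Hadoop config so the NodeManager binds a free port, and only in that case does the executor fall back to the value in its own `SparkConf`. A small sketch of the rule with assumed values:

```scala
// Sketch of the resolution rule; in the real code the first value comes from
// Utils.getSparkOrYarnConfig and the fallback from the executor's SparkConf.
def resolveShuffleServicePort(yarnOrSparkValue: Int, sparkConfValue: Int): Int =
  if (yarnOrSparkValue == 0) sparkConfValue else yarnOrSparkValue

assert(resolveShuffleServicePort(0, 7337) == 7337)      // YARN picked the port dynamically
assert(resolveShuffleServicePort(7338, 7337) == 7338)   // normal case: use the resolved port
```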
if (externalShuffleServiceEnabled @@ -191,6 +200,7 @@ private[spark] class BlockManager( executorId, blockTransferService.hostName, blockTransferService.port) shuffleServerId = if (externalShuffleServiceEnabled) { + logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId @@ -222,7 +232,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f70f701494db..f45bff34d4db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -69,8 +69,9 @@ class BlockManagerMaster( } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { + driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + GetLocationsMultipleBlockIds(blockIds)) } /** @@ -103,7 +104,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -115,7 +116,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -129,7 +130,7 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 5dc0c537cbb6..6fec5240707a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -372,7 +372,8 @@ class BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds( + blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } diff --git 
a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 56a33d5ca7d6..3f8d26e1d4ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -144,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -154,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -168,7 +168,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index ebad5bc5ab28..22878783fca6 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -32,7 +32,7 @@ import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -80,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. 
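`DiskBlockManager` above is a representative caller of the new `ShutdownHookManager` utility that this patch extracts from `Utils`. A sketch of the registration pattern (Spark-internal code only, since the object is `private[spark]`); per `SparkShutdownHook.compareTo` later in the patch, hooks with higher priority values are polled first, so `TEMP_DIR_SHUTDOWN_PRIORITY + 1` runs just before the generic temp-directory cleanup hook:

```scala
// Illustrative sketch of the new API surface shown elsewhere in this patch.
import org.apache.spark.util.ShutdownHookManager

val hookRef = ShutdownHookManager.addShutdownHook(
    ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () =>
  // runs at JVM shutdown, before the generic temp-dir cleanup hook
  println("Shutdown hook called")
}

// If the component is stopped explicitly, unregister the hook to avoid a leak:
ShutdownHookManager.removeShutdownHook(hookRef)
```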
tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -240,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c8356467fab8..779c0ba08359 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -106,7 +106,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -121,7 +125,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -246,11 +250,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39b..d8b90568b7b9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -64,11 +64,11 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 
718aea7e1dc2..f2da41772410 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -352,7 +352,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
- + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0c94204df653..fb4556b83685 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -860,7 +860,7 @@ private[ui] class TaskDataSource( } val peakExecutionMemoryUsed = taskInternalAccumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index ffea9817c0b0..81f168a447ea 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo @@ -167,7 +167,7 @@ private[ui] object RDDOperationGraph extends Logging { def makeDotFile(graph: RDDOperationGraph): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") - dotFile.append(makeDotSubgraph(graph.rootCluster, indent = " ")) + makeDotSubgraph(dotFile, graph.rootCluster, indent = " ") graph.edges.foreach { edge => dotFile.append(s""" ${edge.fromId}->${edge.toId};\n""") } dotFile.append("}") val result = dotFile.toString() @@ -180,18 +180,19 @@ private[ui] object RDDOperationGraph extends Logging { s"""${node.id} [label="${node.name} [${node.id}]"]""" } - /** Return the dot representation of a subgraph in an RDDOperationGraph. */ - private def makeDotSubgraph(cluster: RDDOperationCluster, indent: String): String = { - val subgraph = new StringBuilder - subgraph.append(indent + s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent + s""" label="${cluster.name}";\n""") + /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. 
*/ + private def makeDotSubgraph( + subgraph: StringBuilder, + cluster: RDDOperationCluster, + indent: String): Unit = { + subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") + subgraph.append(indent).append(s""" label="${cluster.name}";\n""") cluster.childNodes.foreach { node => - subgraph.append(indent + s" ${makeDotNode(node)};\n") + subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } cluster.childClusters.foreach { cscope => - subgraph.append(makeDotSubgraph(cscope, indent + " ")) + makeDotSubgraph(subgraph, cscope, indent + " ") } - subgraph.append(indent + "}\n") - subgraph.toString() + subgraph.append(indent).append("}\n") } } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ebead830c646..150d82b3930e 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -181,7 +181,7 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one val innerClasses = getInnerClosureClasses(func) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c600319d9ddb..cbc94fd6d54d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -790,7 +790,7 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 000000000000..61ff9b89ec1c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. + */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. 
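The comment above explains why a child of an already-registered directory is skipped; the check itself is a prefix comparison, as in this small sketch with made-up paths:

```scala
// Mirrors the hasRootAsShutdownDeleteDir test shown next: a path is considered covered
// when a *different* registered path is a prefix of it, so only the parent hook deletes it.
val registeredRoots = Set("/tmp/spark-abc123")     // assumed registered temp dir
val candidate = "/tmp/spark-abc123/blockmgr-0"     // assumed child directory

val coveredByParent = registeredRoots.exists { root =>
  candidate != root && candidate.startsWith(root)
}
assert(coveredByParent)   // the child is skipped; deleting the parent removes it anyway
```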
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Try(Utils.logUncaughtExceptions(hooks.poll().run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57e..724818724733 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4012d0e83f7..831331222671 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -65,21 +65,6 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. 
- */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 - /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -90,9 +75,6 @@ private[spark] object Utils extends Logging { @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -205,86 +187,6 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * JDK equivalent of `chmod 700 file`. 
* @@ -333,7 +235,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -973,9 +875,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1466,7 +1366,7 @@ private[spark] object Utils extends Logging { file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) - logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") } stringBuffer.toString } @@ -1478,27 +1378,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -2221,37 +2100,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2286,70 +2134,17 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } -} - -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. 
+ * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). */ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.getBoolean("spark.dynamicAllocation.enabled", false) && + conf.getInt("spark.executor.instances", 0) == 0 } - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } - } - - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } - } - -} - -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } - - def run(): Unit = hook() - } /** diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 8fa72597db24..40fefe2c9d14 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -24,7 +24,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, 
- PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } @@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index e56a3f0b6d12..ab480b60adae 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -32,9 +32,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; +import org.apache.spark.unsafe.Platform; public abstract class AbstractBytesToBytesMapSuite { @@ -80,13 +78,8 @@ public void tearDown() { private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - size - ); + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -108,7 +101,7 @@ private static boolean arrayEquals( long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), actualAddr.getBaseOffset(), expected.length @@ -124,7 +117,7 @@ public void emptyMap() { final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); @@ -141,14 +134,14 @@ public void setAndRetrieveAKey() { final byte[] valueData = getRandomByteArray(recordLengthWords); try { final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); // After storing the key and value, the other location methods should return results that @@ -159,7 +152,8 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), 
recordLengthBytes)); // After calling lookup() the location should still point to the correct data. - Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); @@ -168,10 +162,10 @@ public void setAndRetrieveAKey() { try { Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); Assert.fail("Should not be able to set a new value for a key"); @@ -191,25 +185,25 @@ private void iteratorTestBase(boolean destructive) throws Exception { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( null, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 0, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } @@ -228,14 +222,13 @@ private void iteratorTestBase(boolean destructive) throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( + final long value = Platform.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -284,16 +277,16 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH )); } @@ -308,18 +301,18 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), value, - 
LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); for (long j : key) { @@ -354,16 +347,16 @@ public void randomizedStressTest() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -379,7 +372,8 @@ public void randomizedStressTest() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -405,16 +399,16 @@ public void randomizedTestWithRecordsLargerThanPageSize() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -429,7 +423,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -447,12 +442,10 @@ public void failureToAllocateFirstPage() { try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = - map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); Assert.assertFalse(loc.putNewKey( - emptyArray, LONG_ARRAY_OFFSET, 0, - emptyArray, LONG_ARRAY_OFFSET, 0 - )); + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); } @@ -468,8 +461,9 @@ public void failureToGrow() { int i; for (i = 0; i < 1024; i++) { final long[] arr = new long[]{i}; - final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); - success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -541,15 +535,15 @@ public void testPeakMemoryUsed() { try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 
8).putNewKey( + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -567,4 +561,13 @@ public void testPeakMemoryUsed() { map.free(); } } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 83049b8a21fc..445a37b83e98 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -49,7 +49,7 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -166,14 +166,14 @@ private void assertSpillFilesWereCleanedUp() { private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[]{ value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix); + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -205,7 +205,7 @@ public void testSortingOnlyByPrefix() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); @@ -253,7 +253,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } sorter.cleanupResources(); @@ -265,7 +265,7 @@ public void testFillingPage() throws Exception { final UnsafeExternalSorter sorter = newSorter(); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.cleanupResources(); 
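Throughout these suites the pattern is the same: the old `PlatformDependent.UNSAFE` accessors and standalone `*_ARRAY_OFFSET` imports are replaced by the consolidated `org.apache.spark.unsafe.Platform` statics. A compact Scala sketch of the lookup-then-putNewKey idiom these tests exercise, using only the new entry points; the memory-manager fixtures (`taskMemoryManager`, `shuffleMemoryManager`, `PAGE_SIZE_BYTES`) are assumed to be the same ones the surrounding suites construct:

```scala
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap

// Sketch only: taskMemoryManager, shuffleMemoryManager and PAGE_SIZE_BYTES are
// assumed fixtures borrowed from the surrounding test suites.
val map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES)
try {
  val key = Array[Long](42L)
  val value = Array[Long](4242L)
  // Probe with the unified Platform constant instead of a standalone
  // LONG_ARRAY_OFFSET import or PlatformDependent.LONG_ARRAY_OFFSET.
  val loc = map.lookup(key, Platform.LONG_ARRAY_OFFSET, 8)
  if (!loc.isDefined) {
    // putNewKey copies both buffers into the map's own data pages.
    loc.putNewKey(key, Platform.LONG_ARRAY_OFFSET, 8, value, Platform.LONG_ARRAY_OFFSET, 8)
  }
  // Reads now go through Platform.getLong rather than PlatformDependent.UNSAFE.getLong.
  val valueAddress = loc.getValueAddress
  assert(Platform.getLong(valueAddress.getBaseObject, valueAddress.getBaseOffset) == 4242L)
} finally {
  map.free()
}
```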
assertSpillFilesWereCleanedUp(); @@ -292,25 +292,25 @@ public void sortingRecordsThatExceedPageSize() throws Exception { iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Small record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); assertFalse(iter.hasNext()); sorter.cleanupResources(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539..778e813df6b5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,7 +26,7 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -36,11 +36,7 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, length); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -76,14 +72,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -113,7 +105,7 @@ public int compare(long prefix1, long prefix2) { position = 
dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). - final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 48f549575f4d..5b84acf40be4 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -160,7 +160,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("internal accumulators in TaskContext") { - val accums = InternalAccumulator.create() + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) val internalMetricsToAccums = taskContext.internalMetricsToAccumulators val collectedInternalAccums = taskContext.collectInternalAccumulators() @@ -181,26 +182,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext("local", "test") sc.addSparkListener(listener) // Have each task add 1 to the internal accumulator - sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... 
numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() } test("internal accumulators in multiple stages") { @@ -210,7 +215,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.addSparkListener(listener) // Each stage creates its own set of internal accumulators so the // values for the same metric should not be mixed up across stages - sc.parallelize(1 to 100, numPartitions) + val rdd = sc.parallelize(1 to 100, numPartitions) .map { i => (i, i) } .mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 @@ -226,16 +231,20 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 iter } - .count() - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) - val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) - val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + rdd.count() } test("internal accumulators in fully resubmitted stages") { @@ -267,7 +276,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex // This says use 1 core and retry tasks up to 2 times sc = new SparkContext("local[1, 2]", "test") sc.addSparkListener(listener) - sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => val taskContext = TaskContext.get() taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 // Fail the first attempts of a subset of the tasks @@ -275,28 +284,32 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex throw new Exception("Failing a task intentionally.") } iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should 
always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + rdd.count() } } @@ -312,20 +325,27 @@ private[spark] object AccumulatorSuite { testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Verify that the accumulator does not already exist + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { jobId => + if (jobId == 0) { + // The first job is a dummy one to verify that the accumulator does not already exist + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + } else { + // In the subsequent jobs, verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } + } + // Run the jobs sc.parallelize(1 to 10).count() - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) testBody - // Verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) } } @@ -335,10 +355,22 @@ private[spark] object AccumulatorSuite { private class SaveInfoListener extends SparkListener { private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + /** 
Register a callback to be called on job end. */ + def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + jobCompletionCallback = callback + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (jobCompletionCallback != null) { + jobCompletionCallback(jobEnd.jobId) + } + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { completedStageInfos += stageCompleted.stageInfo } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 34caca892891..116f027a0f98 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -787,6 +787,24 @@ class ExecutorAllocationManagerSuite Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. 
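In other words, a failed task is expected to re-enter the pending backlog, so the executor target must not drop. A hedged sketch of the accounting this test pins down; the real ExecutorAllocationManager tracks this state per stage and its field and method names differ, so everything below is illustrative:

```scala
// Illustrative accounting only; not the actual ExecutorAllocationManager internals.
case class StageBacklog(numTasks: Int, running: Int, finished: Int) {
  // A failed task is neither running nor finished, so it falls back into the
  // pending count instead of disappearing from the backlog.
  def pending: Int = numTasks - running - finished
}

def maxNumExecutorsNeeded(stages: Seq[StageBacklog], tasksPerExecutor: Int): Int = {
  val outstanding = stages.map(s => s.pending + s.running).sum
  (outstanding + tasksPerExecutor - 1) / tasksPerExecutor  // ceiling division
}

// One single-task stage: while the task runs -> 1 executor needed; after the
// task fails (and will be resubmitted) -> still 1 executor needed, as asserted below.
```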
+ val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index c38d70252add..e846a72c888c 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -36,7 +36,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) - rpcHandler = new ExternalShuffleBlockHandler(transportConf) + rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 69cb4b44cf7e..aa50a49c5023 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) } + // Run a 3-task map job in which task 1 always fails with a exception message that + // depends on the failure number, and check that we get the last failure. 
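Concretely, by the time the job fails, the driver-side SparkException should wrap the user's exception with the message from the last (second) attempt. A small sketch of the nesting the test below asserts; the helper is illustrative and not part of the suite:

```scala
// Illustrative helper: flatten a Throwable's cause chain for inspection.
def causeChain(t: Throwable): List[Throwable] =
  if (t == null) Nil else t :: causeChain(t.getCause)

// For the job below, the chain reaching the driver should read:
//   SparkException                           (Spark's wrapper)
//   -> UserException("oops")                 (thrown by the failing task)
//   -> IllegalArgumentException("failed=2")  (message from the last, second attempt)
```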
+ test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5c57940fa5f7..d4f2ea87650a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -285,4 +285,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 757e0ce3d278..1110ca6051a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -159,7 +159,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") @@ -325,6 +324,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -338,6 +339,8 @@ class SparkSubmitSuite "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -356,6 +359,7 @@ class SparkSubmitSuite "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -501,6 +505,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 08c41a897a86..1f2a0f0d309c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala 
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -283,6 +283,26 @@ class StandaloneDynamicAllocationSuite assert(master.apps.head.getExecutorLimit === 1000) } + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + assert(master.apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(master.apps.head.getExecutorLimit === 1) + } + // =============================== // | Utility methods for testing | // =============================== diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 20d0201a364a..242bf4b5566e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -40,6 +40,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86dff8fb577d..2e8688cf41d9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -133,11 +133,11 @@ class DAGSchedulerSuite val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null, conf, true) { - override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). getOrElse(Seq()) - }.toSeq + }.toIndexedSeq } override def removeExecutor(execId: String) { // don't need to propagate to the driver, which we don't have @@ -242,7 +242,7 @@ class DAGSchedulerSuite /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. 
*/ @@ -926,7 +926,7 @@ class DAGSchedulerSuite assertLocations(reduceTaskSet, Seq(Seq("hostA"))) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() } test("reduce task locality preferences should only include machines with largest map outputs") { @@ -950,7 +950,29 @@ class DAGSchedulerSuite assertLocations(reduceTaskSet, Seq(hosts)) complete(reduceTaskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) - assertDataStructuresEmpty + assertDataStructuresEmpty() + } + + test("stages with both narrow and shuffle dependencies use narrow ones for locality") { + // Create an RDD that has both a shuffle dependency and a narrow dependency (e.g. for a join) + val rdd1 = new MyRDD(sc, 1, Nil) + val rdd2 = new MyRDD(sc, 1, Nil, locations = Seq(Seq("hostB"))) + val shuffleDep = new ShuffleDependency(rdd1, null) + val narrowDep = new OneToOneDependency(rdd2) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep, narrowDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) + + // Reducer should run where RDD 2 has preferences, even though though it also has a shuffle dep + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostB"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() } test("Spark exceptions should include call site in stack trace") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f7cc4bb61d57..edbdb485c5ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 3aa672f8b713..69888b2694ba 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import scala.io.Source import scala.collection.JavaConversions._ import scala.xml.Node @@ -603,6 +604,44 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + test("job stages should have expected dotfile under DAG visualization") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + rdd.count() + + val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label="Stage 0";\n subgraph ")) + assert(stage0.contains("{\n 
label="parallelize";\n " + + "0 [label="ParallelCollectionRDD [0]"];\n }")) + assert(stage0.contains("{\n label="map";\n " + + "1 [label="MapPartitionsRDD [1]"];\n }")) + assert(stage0.contains("{\n label="groupBy";\n " + + "2 [label="MapPartitionsRDD [2]"];\n }")) + + val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label="Stage 1";\n subgraph ")) + assert(stage1.contains("{\n label="groupBy";\n " + + "3 [label="ShuffledRDD [3]"];\n }")) + assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]"];\n }")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]"];\n }")) + + val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]"];\n }")) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f35..b140387d309f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index dde95f377843..343a4139b0ca 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -163,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite { } test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8f7e402d5f2a..1fb81ad565b4 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -720,4 +720,18 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) } + + test("isDynamicAllocationEnabled") { + val conf = new SparkConf() + assert(Utils.isDynamicAllocationEnabled(conf) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "false")) === 
false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "true")) === true) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "1")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "0")) === true) + } + } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh deleted file mode 100755 index 4311c8c9e4ca..000000000000 --- a/dev/create-release/create-release.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Quick-and-dirty automation of making maven and binary releases. Not robust at all. -# Publishes releases to Maven and packages/copies binary release artifacts. -# Expects to be run in a totally empty directory. -# -# Options: -# --skip-create-release Assume the desired release tag already exists -# --skip-publish Do not publish to Maven central -# --skip-package Do not package and upload binary artifacts -# Would be nice to add: -# - Send output to stderr and have useful logging in stdout - -# Note: The following variables must be set before use! -ASF_USERNAME=${ASF_USERNAME:-pwendell} -ASF_PASSWORD=${ASF_PASSWORD:-XXX} -GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} -GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} -# Allows publishing under a different version identifier than -# was present in the actual release sources (e.g. rc-X) -PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} -NEXT_VERSION=${NEXT_VERSION:-1.2.1} -RC_NAME=${RC_NAME:-rc2} - -M2_REPO=~/.m2/repository -SPARK_REPO=$M2_REPO/org/apache/spark -NEXUS_ROOT=https://repository.apache.org/service/local/staging -NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi -JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} - -set -e - -GIT_TAG=v$RELEASE_VERSION-$RC_NAME - -if [[ ! "$@" =~ --skip-create-release ]]; then - echo "Creating release commit and publishing to Apache repository" - # Artifact publishing - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ - -b $GIT_BRANCH - pushd spark - export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - - # Create release commits and push them to github - # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build - # or before we coin the release commit. This helps avoid races where - # other people add commits to this branch while we are in the middle of building. 
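The UtilsSuite addition above pins down when dynamic allocation counts as enabled precisely enough to restate it as a one-liner. This is a sketch of the asserted semantics, not necessarily how `Utils.isDynamicAllocationEnabled` is implemented:

```scala
import org.apache.spark.SparkConf

// Sketch of the semantics the new UtilsSuite assertions pin down; the real
// helper lives in org.apache.spark.util.Utils and may be written differently.
def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
  conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
    conf.getInt("spark.executor.instances", 0) == 0
}

// Matches the assertions above:
//   flag unset                                  -> false
//   flag=true                                   -> true
//   flag=true, spark.executor.instances=1       -> false (explicit executors win)
//   flag=true, spark.executor.instances=0       -> true
```

The same rule explains the new SparkContextSuite test earlier in this section: with both the flag and `spark.executor.instances=6` set, no ExecutorAllocationManager is created and the explicit instance count is kept.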
- cur_ver="${RELEASE_VERSION}-SNAPSHOT" - rel_ver="${RELEASE_VERSION}" - next_ver="${NEXT_VERSION}-SNAPSHOT" - - old="^\( \{2,4\}\)${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.11 - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.10 - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . 
-type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. - - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. 
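A few lines up, this (now deleted) script writes hash-only `.md5`/`.sha1` sidecars next to every staged artifact via the `md5sum`/`shasum` branch, and the replacement script introduced below keeps the same convention. For orientation only, the equivalent digests can be reproduced on the JVM; this sketch is not part of the release tooling, and the artifact names are placeholders:

```scala
import java.nio.file.{Files, Paths}
import java.security.MessageDigest

// Rough JVM equivalent of the hash-only checksum sidecars; the release
// scripts themselves shell out to md5/md5sum/shasum and gpg instead.
def hexDigest(path: String, algorithm: String): String = {
  val bytes = Files.readAllBytes(Paths.get(path))
  MessageDigest.getInstance(algorithm).digest(bytes).map(b => "%02x".format(b & 0xff)).mkString
}

// hexDigest("some-artifact.jar", "MD5")   -> contents of the .md5 sidecar
// hexDigest("some-artifact.jar", "SHA-1") -> contents of the .sha1 sidecar
```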
- make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 000000000000..d0b3a54dde1d --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,321 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. 
+REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. + +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +rm .gitignore +rm -rf .git +cd .. + +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
+ + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & + wait + rm -rf spark-$SPARK_VERSION-bin-*/ + + # Copy data + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" + echo "Copying release tarballs to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + rsync -e "$SSH" spark-* $USER_HOST:$dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + exit 0 +fi + +if [[ "$1" == "docs" ]]; then + # Documentation + cd spark + echo "Building Spark docs" + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" + cd docs + # Compile docs with Java 7 to use nicer format + # TODO: Make configurable to add this: PRODUCTION=1 + PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build + echo "Copying release documentation to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + cd .. + exit 0 +fi + +if [[ "$1" == "publish-snapshot" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + if [[ ! 
$SPARK_VERSION == *"SNAPSHOT"* ]]; then + echo "ERROR: Snapshots must have a version containing SNAPSHOT" + echo "ERROR: You gave version '$SPARK_VERSION'" + exit 1 + fi + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver deploy + ./dev/change-scala-version.sh 2.11 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + -DskipTests $PUBLISH_PROFILES clean deploy + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + rm $tmp_settings + cd .. + exit 0 +fi + +if [[ "$1" == "publish-release" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + tmp_repo=$(mktemp -d spark-repo-XXXXX) + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver clean install + + ./dev/change-scala-version.sh 2.11 + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + -DskipTests $PUBLISH_PROFILES clean install + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + ./dev/change-version-to-2.10.sh + + pushd $tmp_repo/org/apache/spark + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ + --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + sha1sum $file | cut -f1 -d' ' > $file.sha1 + done + + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . 
-type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + popd + rm -rf $tmp_repo + cd .. + exit 0 +fi + +cd .. +rm -rf spark +echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'" diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh new file mode 100755 index 000000000000..b0a3374becc6 --- /dev/null +++ b/dev/create-release/release-tag.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. 
+rm -rf spark diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ad4b76695c9f..b9bdec3d7086 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -159,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) diff --git a/dev/run-tests.py b/dev/run-tests.py index d1852b95bb29..f689425ee40b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -303,6 +303,8 @@ def build_spark_sbt(hadoop_version): "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a9717ff9569c..346452f3174e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -134,7 +134,7 @@ def contains_file(self, filename): # files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't # fail other PRs. streaming_kinesis_asl = Module( - name="kinesis-asl", + name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", @@ -147,7 +147,7 @@ def contains_file(self, filename): "ENABLE_KINESIS_TESTS": "1" }, sbt_test_goals=[ - "kinesis-asl/test", + "streaming-kinesis-asl/test", ] ) @@ -181,6 +181,7 @@ def contains_file(self, filename): dependencies=[streaming], source_file_regexes=[ "external/mqtt", + "external/mqtt-assembly", ], sbt_test_goals=[ "streaming-mqtt/test", @@ -306,6 +307,7 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, + streaming_mqtt, streaming_kinesis_asl ], source_file_regexes=[ diff --git a/docs/configuration.md b/docs/configuration.md index c60dd16839c0..4a6e4dd05b66 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -389,7 +389,8 @@ Apart from these, the following properties are also available, and may be useful Implementation to use for transferring shuffle and cached blocks between executors. There are two implementations available: netty and nio. Netty-based block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + starting in 1.2, and nio block transfer is deprecated in Spark 1.5.0 and will + be removed in Spark 1.6.0. @@ -1560,7 +1561,11 @@ The following variables can be set in `spark-env.sh`: PYSPARK_PYTHON - Python binary executable to use for PySpark. + Python binary executable to use for PySpark in both driver and workers (default is `python`). + + + PYSPARK_DRIVER_PYTHON + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). 
SPARK_LOCAL_IP diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 99f8c827f767..c861a763d622 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -768,16 +768,14 @@ class GraphOps[VD, ED] { // Loop until no messages remain or maxIterations is achieved var i = 0 while (activeMessages > 0 && i < maxIterations) { - // Receive the messages: ----------------------------------------------------------------------- - // Run the vertex program on all vertices that receive messages - val newVerts = g.vertices.innerJoin(messages)(vprog).cache() - // Merge the new vertex values back into the graph - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() - // Send Messages: ------------------------------------------------------------------------------ - // Vertices that didn't receive a message above don't appear in newVerts and therefore don't - // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked - // on edges in the activeDir of vertices in newVerts - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDir))).cache() + // Receive the messages and update the vertices. + g = g.joinVertices(messages)(vprog).cache() + val oldMessages = messages + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } diff --git a/docs/ml-ann.md b/docs/ml-ann.md new file mode 100644 index 000000000000..d5ddd92af1e9 --- /dev/null +++ b/docs/ml-ann.md @@ -0,0 +1,123 @@ +--- +layout: global +title: Multilayer perceptron classifier - ML +displayTitle: ML - Multilayer perceptron classifier +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Examples** + +
+ +
+ +{% highlight scala %} +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row + +// Load training data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt").toDF() +// Split the data into train and test +val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) +val train = splits(0) +val test = splits(1) +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +val layers = Array[Int](4, 5, 4, 3) +// create the trainer and set its parameters +val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) +// train the model +val model = trainer.fit(train) +// compute precision on the test set +val result = model.transform(test) +val predictionAndLabels = result.select("prediction", "label") +val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") +println("Precision:" + evaluator.evaluate(predictionAndLabels)) +{% endhighlight %} + +
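As an optional follow-up (a minimal sketch, assuming the `predictionAndLabels` DataFrame from the Scala example above), per-class results can be inspected with plain DataFrame operations:

{% highlight scala %}
// Rough confusion-matrix view: count (label, prediction) pairs.
// Assumes `predictionAndLabels` from the example above.
predictionAndLabels
  .groupBy("label", "prediction")
  .count()
  .orderBy("label", "prediction")
  .show()
{% endhighlight %}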
+ +
+ +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +// Load training data +String path = "data/mllib/sample_multiclass_classification_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); +DataFrame dataFrame = sqlContext.createDataFrame(data, LabeledPoint.class); +// Split the data into train and test +DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); +DataFrame train = splits[0]; +DataFrame test = splits[1]; +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +int[] layers = new int[] {4, 5, 4, 3}; +// create the trainer and set its parameters +MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); +// train the model +MultilayerPerceptronClassificationModel model = trainer.fit(train); +// compute precision on the test set +DataFrame result = model.transform(test); +DataFrame predictionAndLabels = result.select("prediction", "label"); +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); +System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); +{% endhighlight %} +
+ +
diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md new file mode 100644 index 000000000000..958c6f5e4716 --- /dev/null +++ b/docs/ml-decision-tree.md @@ -0,0 +1,510 @@ +--- +layout: global +title: Decision Trees - SparkML +displayTitle: ML - Decision Trees +--- + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + + +# Overview + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). In this section, we demonstrate the Pipelines API for Decision Trees. + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described in the [Ensembles guide](ml-ensembles.html). + +# Inputs and Outputs (Predictions) + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +## Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
| Param name  | Type(s) | Default    | Description      |
| ----------- | ------- | ---------- | ---------------- |
| labelCol    | Double  | "label"    | Label to predict |
| featuresCol | Vector  | "features" | Feature vector   |
## Output Columns

| Param name       | Type(s) | Default         | Description | Notes |
| ---------------- | ------- | --------------- | ----------- | ----- |
| predictionCol    | Double  | "prediction"    | Predicted label | |
| rawPredictionCol | Vector  | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only |
| probabilityCol   | Vector  | "probability"   | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only |
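As noted above, the optional output columns can be excluded by setting the corresponding Params to empty strings; a minimal sketch (the column names here are illustrative):

{% highlight scala %}
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Skip the optional rawPrediction and probability output columns by setting
// their Params to empty strings; "indexedLabel"/"indexedFeatures" are illustrative.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setRawPredictionCol("")
  .setProbabilityCol("")
{% endhighlight %}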
+ +# Examples + +The below examples demonstrate the Pipelines API for Decision Trees. The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +## Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
+
+ +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a DecisionTree model. +val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and tree in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] +println("Learned classification tree model:\n" + treeModel.toDebugString) +{% endhighlight %} +
+ +
+ +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a DecisionTree model. +DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and tree in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); +double accuracy = evaluator.evaluate(predictions); +System.out.println("Test Error = " + (1.0 - accuracy)); + +DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel)(model.stages()[2]); +System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); +{% endhighlight %} +
+ +
+ +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# We specify maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a DecisionTree model. +dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + +# Chain indexers and tree in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +treeModel = model.stages[2] +print treeModel # summary only +{% endhighlight %} +
+ +
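As a brief follow-up (a sketch assuming the `predictions` DataFrame from the Scala classification example above), the class conditional probabilities mentioned in the overview are available in the `probability` output column:

{% highlight scala %}
// Each row's "probability" vector holds the predicted class conditional probabilities.
predictions.select("predictedLabel", "probability", "label").show(5)
{% endhighlight %}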
+ + +## Regression + +
+
+ +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Automatically identify categorical features, and index them. +// Here, we treat features with > 4 distinct values as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a DecisionTree model. +val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + +// Chain indexers and tree in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +// We negate the RMSE value since RegressionEvalutor returns negated RMSE +// (since evaluation metrics are meant to be maximized by CrossValidator). +val rmse = - evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] +println("Learned regression tree model:\n" + treeModel.toDebugString) +{% endhighlight %} +
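The comment about negated RMSE reflects that evaluation metrics are maximized during model selection; a hedged sketch of plugging this evaluator into cross-validation (reusing `pipeline`, `dt`, `evaluator`, and `trainingData` from the Scala example above, with an illustrative parameter grid):

{% highlight scala %}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Tune maxDepth by 3-fold cross-validation. CrossValidator maximizes the
// evaluator's metric, which is why RMSE is reported negated.
val paramGrid = new ParamGridBuilder()
  .addGrid(dt.maxDepth, Array(2, 5, 10))
  .build()
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
val cvModel = cv.fit(trainingData)
{% endhighlight %}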
+ +
+ +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a DecisionTree model. +DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and tree in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("rmse"); +// We negate the RMSE value since RegressionEvalutor returns negated RMSE +// (since evaluation metrics are meant to be maximized by CrossValidator). +double rmse = - evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel)(model.stages()[2]); +System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); +{% endhighlight %} +
+ +
+ +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# We specify maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a DecisionTree model. +dt = DecisionTreeRegressor(labelCol="indexedLabel", featuresCol="indexedFeatures") + +# Chain indexers and tree in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="rmse") +# We negate the RMSE value since RegressionEvalutor returns negated RMSE +# (since evaluation metrics are meant to be maximized by CrossValidator). +rmse = -evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +treeModel = model.stages[1] +print treeModel # summary only +{% endhighlight %} +
+ +
diff --git a/docs/ml-features.md b/docs/ml-features.md index fa0ad1f00ab1..642a4b4c5318 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -55,7 +55,7 @@ rescaledData.select("features", "label").take(3).foreach(println)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.HashingTF; @@ -70,7 +70,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "Hi I heard about Spark"), RowFactory.create(0, "I wish Java could use case classes"), RowFactory.create(1, "Logistic regression models are neat") @@ -153,7 +153,7 @@ result.select("result").take(3).foreach(println)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -167,10 +167,10 @@ JavaSparkContext jsc = ... SQLContext sqlContext = ... // Input data: Each row is a bag of words from a sentence or document. -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) @@ -217,29 +217,41 @@ for feature in result.select("result").take(3): [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes + "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result.
{% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} val sentenceDataFrame = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) {% endhighlight %}
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -250,10 +262,10 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -267,22 +279,29 @@ for (Row r : wordsDataFrame.select("words", "label").take(3)) { for (String word : words) System.out.print(word + " "); System.out.println(); } + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); {% endhighlight %}
{% highlight python %} -from pyspark.ml.feature import Tokenizer +from pyspark.ml.feature import Tokenizer, RegexTokenizer sentenceDataFrame = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) {% endhighlight %}
@@ -322,7 +341,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri [`NGram`](api/java/org/apache/spark/ml/feature/NGram.html) takes an input column name, an output column name, and an optional length parameter n (n=2 by default). {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.NGram; @@ -335,10 +354,10 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0D, Lists.newArrayList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1D, Lists.newArrayList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2D, Lists.newArrayList("Logistic", "regression", "models", "are", "neat")) +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -408,7 +427,7 @@ binarizedFeatures.collect().foreach(println)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Binarizer; @@ -420,7 +439,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, 0.1), RowFactory.create(1, 0.8), RowFactory.create(2, 0.2) @@ -492,7 +511,7 @@ result.show()
See the [Java API documentation](api/java/org/apache/spark/ml/feature/PCA.html) for API details. {% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -510,7 +529,7 @@ import org.apache.spark.sql.types.StructType; JavaSparkContext jsc = ... SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Lists.newArrayList( +JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) @@ -574,7 +593,7 @@ polyDF.select("polyFeatures").take(3).foreach(println)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -595,7 +614,7 @@ PolynomialExpansion polyExpansion = new PolynomialExpansion() .setInputCol("features") .setOutputCol("polyFeatures") .setDegree(3); -JavaRDD data = jsc.parallelize(Lists.newArrayList( +JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.dense(-2.0, 2.3)), RowFactory.create(Vectors.dense(0.0, 0.0)), RowFactory.create(Vectors.dense(0.6, -1.1)) @@ -630,12 +649,87 @@ for expanded in polyDF.select("polyFeatures").take(3):
+## Discrete Cosine Transform (DCT) + +The [Discrete Cosine +Transform](https://en.wikipedia.org/wiki/Discrete_cosine_transform) +transforms a length $N$ real-valued sequence in the time domain into +another length $N$ real-valued sequence in the frequency domain. A +[DCT](api/scala/index.html#org.apache.spark.ml.feature.DCT) class +provides this functionality, implementing the +[DCT-II](https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II) +and scaling the result by $1/\sqrt{2}$ such that the representing matrix +for the transform is unitary. No shift is applied to the transformed +sequence (e.g. the $0$th element of the transformed sequence is the +$0$th DCT coefficient and _not_ the $N/2$th). + +
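For reference, the standard (unscaled) DCT-II of a sequence `$x_0, \ldots, x_{N-1}$` is shown below; the implementation described above additionally rescales the result so that the transform matrix is unitary.

`\[
X_k = \sum_{n=0}^{N-1} x_n \cos\left[\frac{\pi}{N}\left(n + \tfrac{1}{2}\right)k\right], \quad k = 0, \ldots, N-1
\]`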
+
+{% highlight scala %} +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors + +val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) +val dctDf = dct.transform(df) +dctDf.select("featuresDCT").show(3) +{% endhighlight %} +
+ +
+{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); +DataFrame dctDf = dct.transform(df); +dctDf.select("featuresDCT").show(3); +{% endhighlight %} +
+
+ ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string values. +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, +you can set the input column with `setInputCol`. **Examples** @@ -779,7 +873,7 @@ encoded.select("id", "categoryVec").foreach(println)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.OneHotEncoder; @@ -793,7 +887,7 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0, "a"), RowFactory.create(1, "b"), RowFactory.create(2, "c"), @@ -1116,7 +1210,7 @@ val bucketedData = bucketizer.transform(dataFrame)
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; @@ -1128,7 +1222,7 @@ import org.apache.spark.sql.types.StructType; double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; -JavaRDD data = jsc.parallelize(Lists.newArrayList( +JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), RowFactory.create(0.0), @@ -1193,7 +1287,7 @@ v_N This example below demonstrates how to transform vectors using a transforming vector value.
-
+
{% highlight scala %} import org.apache.spark.ml.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors @@ -1210,14 +1304,14 @@ val transformer = new ElementwiseProduct() .setOutputCol("transformedVector") // Batch transform the vectors to create new column: -val transformedData = transformer.transform(dataFrame) +transformer.transform(dataFrame).show() {% endhighlight %}
-
+
{% highlight java %} -import com.google.common.collect.Lists; +import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.ElementwiseProduct; @@ -1233,7 +1327,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; // Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( +JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) )); @@ -1248,10 +1342,25 @@ ElementwiseProduct transformer = new ElementwiseProduct() .setInputCol("vector") .setOutputCol("transformedVector"); // Batch transform the vectors to create new column: -DataFrame transformedData = transformer.transform(dataFrame); +transformer.transform(dataFrame).show(); + +{% endhighlight %} +
+ +
+{% highlight python %} +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] +df = sqlContext.createDataFrame(data, ["vector"]) +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") +transformer.transform(df).show() {% endhighlight %}
+
## VectorAssembler @@ -1370,3 +1479,242 @@ print(output.select("features", "clicked").first()) # Feature Selectors +## VectorSlicer + +`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a +sub-array of the original features. It is useful for extracting features from a vector column. + +`VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column +whose values are selected via those indices. There are two types of indices, + + 1. Integer indices that represents the indices into the vector, `setIndices()`; + + 2. String indices that represents the names of features into the vector, `setNames()`. + *This requires the vector column to have an `AttributeGroup` since the implementation matches on + the name field of an `Attribute`.* + +Specification by integer and string are both acceptable. Moreover, you can use integer index and +string name simultaneously. At least one feature must be selected. Duplicate features are not +allowed, so there can be no overlap between selected indices and names. Note that if names of +features are selected, an exception will be threw out when encountering with empty input attributes. + +The output vector will order features with the selected indices first (in the order given), +followed by the selected names (in the order given). + +**Examples** + +Suppose that we have a DataFrame with the column `userFeatures`: + +~~~ + userFeatures +------------------ + [0.0, 10.0, 0.5] +~~~ + +`userFeatures` is a vector column that contains three user features. Assuming that the first column +of `userFeatures` are all zeros, so we want to remove it and only the last two columns are selected. +The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` then produces a new vector +column named `features`: + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] +~~~ + +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] + ["f1", "f2", "f3"] | ["f2", "f3"] +~~~ + +
+
+ +[`VectorSlicer`](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) takes an input +column name with specified indices or names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0) +) + +val defaultAttr = NumericAttribute.defaultAttr +val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) +val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + +val dataRDD = sc.parallelize(data).map(Row.apply) +val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) + +val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + +slicer.setIndices(1).setNames("f3") +// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + +val output = slicer.transform(dataset) +println(output.select("userFeatures", "features").first()) +{% endhighlight %} +
+ +
+
+[`VectorSlicer`](api/java/org/apache/spark/ml/feature/VectorSlicer.html) takes an input column name
+with specified indices or names and an output column name.
+
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.attribute.Attribute;
+import org.apache.spark.ml.attribute.AttributeGroup;
+import org.apache.spark.ml.attribute.NumericAttribute;
+import org.apache.spark.ml.feature.VectorSlicer;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+Attribute[] attrs = new Attribute[]{
+  NumericAttribute.defaultAttr().withName("f1"),
+  NumericAttribute.defaultAttr().withName("f2"),
+  NumericAttribute.defaultAttr().withName("f3")
+};
+AttributeGroup group = new AttributeGroup("userFeatures", attrs);
+
+JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+  RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
+  RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
+));
+
+DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+
+VectorSlicer vectorSlicer = new VectorSlicer()
+  .setInputCol("userFeatures").setOutputCol("features");
+
+vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
+// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})
+
+DataFrame output = vectorSlicer.transform(dataset);
+
+System.out.println(output.select("userFeatures", "features").first());
+{% endhighlight %}
+</div>
+
+ +## RFormula + +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: + +~~~ +id | country | hour | clicked +---|---------|------|--------- + 7 | "US" | 18 | 1.0 + 8 | "CA" | 12 | 0.0 + 9 | "NZ" | 15 | 0.0 +~~~ + +If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to +predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame: + +~~~ +id | country | hour | clicked | features | label +---|---------|------|---------|------------------|------- + 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0 + 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0 + 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0 +~~~ + +
+
+ +[`RFormula`](api/scala/index.html#org.apache.spark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
+ +
+ +[`RFormula`](api/java/org/apache/spark/ml/feature/RFormula.html) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) +}); +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) +)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + +DataFrame output = formula.fit(dataset).transform(dataset); +output.select("features", "label").show(); +{% endhighlight %} +
+ +
+ +[`RFormula`](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
+
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b6ca50e98db0..de8fead3529e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -32,10 +32,7 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -Guides for sub-packages of `spark.ml` include: - -* [Feature Extraction, Transformation, and Selection](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API +See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of `spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. **Table of Contents** @@ -179,11 +176,10 @@ There are now several algorithms in the Pipelines API which are not in the lower **Pipelines API Algorithm Guides** * [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Decision Trees for Classification and Regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) - -**Algorithms in `spark.ml`** - * [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) # Code Examples @@ -274,8 +270,9 @@ sc.stop()
{% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; + import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionModel; @@ -294,7 +291,7 @@ SQLContext jsql = new SQLContext(jsc); // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. -List localTraining = Lists.newArrayList( +List localTraining = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), @@ -335,7 +332,7 @@ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. -List localTest = Lists.newArrayList( +List localTest = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); @@ -355,6 +352,74 @@ jsc.stop(); {% endhighlight %}
+
+{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleParamsExample") +sqlContext = SQLContext(sc) + +# Prepare training data. +# We use LabeledPoint. +# Spark SQL can convert RDDs of LabeledPoints into DataFrames. +training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), + LabeledPoint(0.0, [2.0, 1.0, -1.0]), + LabeledPoint(0.0, [2.0, 1.3, 1.0]), + LabeledPoint(1.0, [0.0, 1.2, -0.5])]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training.toDF()) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training.toDF(), paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), + LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), + LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test.toDF()) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +sc.stop() +{% endhighlight %} +
+
## Example: Pipeline @@ -428,8 +493,9 @@ sc.stop()
{% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; + import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; @@ -478,7 +544,7 @@ JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +List localTraining = Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -503,7 +569,7 @@ Pipeline pipeline = new Pipeline() PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +List localTest = Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), @@ -679,8 +745,9 @@ sc.stop()
{% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; + import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; @@ -732,7 +799,7 @@ JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Lists.newArrayList( +List localTraining = Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -781,7 +848,7 @@ crossval.setNumFolds(2); // Use 3+ in practice CrossValidatorModel cvModel = crossval.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Lists.newArrayList( +List localTest = Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index bb875ae2ae6c..fd9ab258e196 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -564,6 +564,34 @@ public class JavaLDAExample { {% endhighlight %}
+
+ +{% highlight python %} +from pyspark.mllib.clustering import LDA, LDAModel +from pyspark.mllib.linalg import Vectors + +# Load and parse the data +data = sc.textFile("data/mllib/sample_lda_data.txt") +parsedData = data.map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) +# Index documents with unique IDs +corpus = parsedData.zipWithIndex().map(lambda x: [x[1], x[0]]).cache() + +# Cluster the documents into three topics using LDA +ldaModel = LDA.train(corpus, k=3) + +# Output topics. Each is a distribution over words (matching word count vectors) +print("Learned topics (as distributions over vocab of " + str(ldaModel.vocabSize()) + " words):") +topics = ldaModel.topicsMatrix() +for topic in range(3): + print("Topic " + str(topic) + ":") + for word in range(0, ldaModel.vocabSize()): + print(" " + str(topics[word][topic])) + +# Save and load model +ldaModel.save(sc, "myModelPath") +sameModel = LDAModel.load(sc, "myModelPath") +{% endhighlight %} +
+
## Streaming k-means diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 7521fb14a7bd..1e00b2083ed7 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBosotedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index bcc066a18552..4d4f5cfdc564 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -41,16 +41,23 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters: [`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. -It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. +It takes an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) -that stores the frequent itemsets with their frequencies. +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. + {% highlight scala %} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.mllib.fpm.FPGrowth + +val data = sc.textFile("data/mllib/sample_fpgrowth.txt") -val transactions: RDD[Array[String]] = ... +val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) val fpg = new FPGrowth() .setMinSupport(0.2) @@ -60,6 +67,14 @@ val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) } + +val minConfidence = 0.8 +model.generateAssociationRules(minConfidence).collect().foreach { rule => + println( + rule.antecedent.mkString("[", ",", "]") + + " => " + rule.consequent.mkString("[", ",", "]") + + ", " + rule.confidence) +} {% endhighlight %}
@@ -68,21 +83,38 @@ model.freqItemsets.collect().foreach { itemset => [`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the FP-growth algorithm. -It take an `RDD` of transactions, where each transaction is an `Array` of items of a generic type. +It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an [`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) -that stores the frequent itemsets with their frequencies. +that stores the frequent itemsets with their frequencies. The following +example illustrates how to mine frequent itemsets and association rules +(see [Association +Rules](mllib-frequent-pattern-mining.html#association-rules) for +details) from `transactions`. {% highlight java %} +import java.util.Arrays; import java.util.List; -import com.google.common.base.Joiner; - +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.fpm.AssociationRules; import org.apache.spark.mllib.fpm.FPGrowth; import org.apache.spark.mllib.fpm.FPGrowthModel; -JavaRDD<List<String>> transactions = ... +SparkConf conf = new SparkConf().setAppName("FP-growth Example"); +JavaSparkContext sc = new JavaSparkContext(conf); + +JavaRDD<String> data = sc.textFile("data/mllib/sample_fpgrowth.txt"); + +JavaRDD<List<String>> transactions = data.map( + new Function<String, List<String>>() { + public List<String> call(String line) { + String[] parts = line.split(" "); + return Arrays.asList(parts); + } + } +); FPGrowth fpg = new FPGrowth() .setMinSupport(0.2) @@ -90,9 +122,202 @@ FPGrowth fpg = new FPGrowth() FPGrowthModel<String> model = fpg.run(transactions); for (FPGrowth.FreqItemset<String> itemset: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); + System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); +} + +double minConfidence = 0.8; +for (AssociationRules.Rule<String> rule + : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); } {% endhighlight %}
+ +
+ +[`FPGrowth`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) implements the +FP-growth algorithm. +It takes an `RDD` of transactions, where each transaction is a `list` of items of a generic type. +Calling `FPGrowth.train` with transactions returns an +[`FPGrowthModel`](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowthModel) +that stores the frequent itemsets with their frequencies. + +{% highlight python %} +from pyspark.mllib.fpm import FPGrowth + +data = sc.textFile("data/mllib/sample_fpgrowth.txt") + +transactions = data.map(lambda line: line.strip().split(' ')) + +model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10) + +result = model.freqItemsets().collect() +for fi in result: + print(fi) +{% endhighlight %} +
+ +
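To make the `minSupport` parameter used in all three FP-growth examples concrete, here is a small framework-free sketch (plain Python; the toy transactions are invented for this illustration and are not the contents of `sample_fpgrowth.txt`). It counts, by brute force, how often each candidate itemset occurs and keeps those whose relative support reaches the threshold; FP-growth computes the same quantity without enumerating every candidate.

{% highlight python %}
from itertools import combinations

# Hypothetical toy transactions, for illustration only.
transactions = [
    {"bread", "milk"},
    {"bread", "butter", "milk"},
    {"butter", "milk"},
    {"bread", "butter"},
]
min_support = 0.5  # plays the same role as setMinSupport(...) above

# Count every itemset that appears in at least one transaction.
counts = {}
for transaction in transactions:
    for size in range(1, len(transaction) + 1):
        for itemset in combinations(sorted(transaction), size):
            counts[itemset] = counts.get(itemset, 0) + 1

# Keep itemsets whose relative support meets the threshold.
frequent = dict((itemset, count) for itemset, count in counts.items()
                if count / float(len(transactions)) >= min_support)
for itemset, count in sorted(frequent.items()):
    print(itemset, count)  # e.g. ('bread', 'milk') appears in 2 of 4 transactions
{% endhighlight %}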
+ +## Association Rules + +
+
+[AssociationRules](api/scala/index.html#org.apache.spark.mllib.fpm.AssociationRules) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +{% highlight scala %} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.fpm.AssociationRules +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset + +val freqItemsets = sc.parallelize(Seq( + new FreqItemset(Array("a"), 15L), + new FreqItemset(Array("b"), 35L), + new FreqItemset(Array("a", "b"), 12L) +)); + +val ar = new AssociationRules() + .setMinConfidence(0.8) +val results = ar.run(freqItemsets) + +results.collect().foreach { rule => + println("[" + rule.antecedent.mkString(",") + + "=>" + + rule.consequent.mkString(",") + "]," + rule.confidence) +} +{% endhighlight %} + +
+ +
+ +[AssociationRules](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) +implements a parallel rule generation algorithm for constructing rules +that have a single item as the consequent. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + +JavaRDD<FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList( + new FreqItemset<String>(new String[] {"a"}, 15L), + new FreqItemset<String>(new String[] {"b"}, 35L), + new FreqItemset<String>(new String[] {"a", "b"}, 12L) +)); + +AssociationRules arules = new AssociationRules() + .setMinConfidence(0.8); +JavaRDD<AssociationRules.Rule<String>> results = arules.run(freqItemsets); + +for (AssociationRules.Rule<String> rule: results.collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); +} +{% endhighlight %} + +
+
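The numbers in the two association-rule examples above can be checked by hand: the confidence of a rule `X => y` is `freq(X union {y}) / freq(X)`. The short sketch below (plain Python, reusing the same toy frequencies) shows why only the rule `{a} => b` survives the 0.8 confidence threshold.

{% highlight python %}
# Illustration only: confidence of X => y is freq(X union {y}) / freq(X).
freq = {
    frozenset(["a"]): 15,
    frozenset(["b"]): 35,
    frozenset(["a", "b"]): 12,
}
min_confidence = 0.8

for antecedent, consequent in [(frozenset(["a"]), "b"), (frozenset(["b"]), "a")]:
    confidence = freq[antecedent | frozenset([consequent])] / float(freq[antecedent])
    verdict = "kept" if confidence >= min_confidence else "dropped"
    print(sorted(antecedent), "=>", consequent, round(confidence, 3), verdict)
# ['a'] => b : 12/15 = 0.8   -> kept
# ['b'] => a : 12/35 = 0.343 -> dropped
{% endhighlight %}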
+ +## PrefixSpan + +PrefixSpan is a sequential pattern mining algorithm described in +[Pei et al., Mining Sequential Patterns by Pattern-Growth: The +PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer +the reader to the referenced paper for a formalization of the sequential +pattern mining problem. + +MLlib's PrefixSpan implementation takes the following parameters: + +* `minSupport`: the minimum support required to be considered a frequent + sequential pattern. +* `maxPatternLength`: the maximum length of a frequent sequential + pattern. Any frequent pattern exceeding this length will not be + included in the results. +* `maxLocalProjDBSize`: the maximum number of items allowed in a + prefix-projected database before local iterative processing of the + projected database begins. This parameter should be tuned with respect + to the size of your executors. + +**Examples** + +The following example illustrates PrefixSpan running on the sequences +(using the same notation as Pei et al.): + +~~~ + <(12)3> + <1(32)(12)> + <(12)5> + <6> +~~~ + +
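Before the API examples, the following self-contained sketch (plain Python, using the four sequences above; the candidate patterns were chosen just for this illustration) spells out what the support of a sequential pattern means: a pattern is a list of itemsets, and a data sequence supports it if the pattern's itemsets can be matched, in order, to itemsets of the sequence as subsets.

{% highlight python %}
# Illustration only: brute-force support counting for sequential patterns.
def contains(sequence, pattern):
    """True if pattern's itemsets match itemsets of sequence, in order, as subsets."""
    matched = 0
    for itemset in sequence:
        if matched < len(pattern) and pattern[matched] <= itemset:
            matched += 1
    return matched == len(pattern)

# The four example sequences <(12)3>, <1(32)(12)>, <(12)5>, <6>.
sequences = [
    [{1, 2}, {3}],
    [{1}, {3, 2}, {1, 2}],
    [{1, 2}, {5}],
    [{6}],
]

for pattern in [[{1, 2}], [{1}, {3}], [{5}]]:
    support = sum(contains(s, pattern) for s in sequences) / float(len(sequences))
    print(pattern, support)
# [{1, 2}]   -> 0.75 (frequent at minSupport = 0.5)
# [{1}, {3}] -> 0.5  (frequent)
# [{5}]      -> 0.25 (not frequent)
{% endhighlight %}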
+
+ +[`PrefixSpan`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel) +that stores the frequent sequences with their frequencies. + +{% highlight scala %} +import org.apache.spark.mllib.fpm.PrefixSpan + +val sequences = sc.parallelize(Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6)) + ), 2).cache() +val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) +val model = prefixSpan.run(sequences) +model.freqSequences.collect().foreach { freqSequence => +println( + freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + ", " + freqSequence.freq) +} +{% endhighlight %} + +
+ +
+ +[`PrefixSpan`](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) implements the +PrefixSpan algorithm. +Calling `PrefixSpan.run` returns a +[`PrefixSpanModel`](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html) +that stores the frequent sequences with their frequencies. + +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.fpm.PrefixSpan; +import org.apache.spark.mllib.fpm.PrefixSpanModel; + +JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) +), 2); +PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); +PrefixSpanModel<Integer> model = prefixSpan.run(sequences); +for (PrefixSpan.FreqSequence<Integer> freqSeq: model.freqSequences().toJavaRDD().collect()) { + System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); +} +{% endhighlight %} + +
+
+ diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index eea864eacf7c..6330c977552d 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -23,19 +23,19 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) - * summary statistics - * correlations - * stratified sampling - * hypothesis testing - * random data generation + * [summary statistics](mllib-statistics.html#summary-statistics) + * [correlations](mllib-statistics.html#correlations) + * [stratified sampling](mllib-statistics.html#stratified-sampling) + * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) - * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [ensembles of trees (Random Forests and Gradient-Boosted Trees)](mllib-ensembles.html) * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) - * alternating least squares (ALS) + * [alternating least squares (ALS)](mllib-collaborative-filtering.html#collaborative-filtering) * [Clustering](mllib-clustering.html) * [k-means](mllib-clustering.html#k-means) * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) @@ -43,19 +43,21 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) - * singular value decomposition (SVD) - * principal component analysis (PCA) + * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) + * [principal component analysis (PCA)](mllib-dimensionality-reduction.html#principal-component-analysis-pca) * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) - * FP-growth + * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) + * [association rules](mllib-frequent-pattern-mining.html#association-rules) + * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) * [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) - * stochastic gradient descent - * limited-memory BFGS (L-BFGS) + * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) + * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) * [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. # spark.ml: high-level APIs for ML pipelines @@ -71,11 +73,14 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. 
-More detailed guides for `spark.ml` include: +Guides for `spark.ml` include: * **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts -* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API -* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API +* Guides on using algorithms within the Pipelines API: + * [Feature transformers](ml-features.html), including a few not in the lower-level `spark.mllib` API + * [Decision trees](ml-decision-tree.html) + * [Ensembles](ml-ensembles.html) + * [Linear methods](ml-linear-methods.html) # Dependencies diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 5732bc4c7e79..6aa881f74918 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -160,4 +160,39 @@ model.save(sc.sc(), "myModelPath"); IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
+ +
+ +Data are read from a file where each line has the format label,feature, +e.g. 4710.28,500.00. The data are split into training and test sets. +A model is created using the training set, and the mean squared error is calculated from the +predicted and real labels in the test set. + +{% highlight python %} +import math +from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel + +data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +# Create label, feature, weight tuples from input data with weight set to default value 1.0. +parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) + +# Split data into training (60%) and test (40%) sets. +training, test = parsedData.randomSplit([0.6, 0.4], 11) + +# Create isotonic regression model from training data. +# The isotonic parameter defaults to true. +model = IsotonicRegression.train(training) + +# Create tuples of predicted and real labels. +predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0])) + +# Calculate mean squared error between predicted and real labels. +meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean() +print("Mean Squared Error = " + str(meanSquaredError)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = IsotonicRegressionModel.load(sc, "myModelPath") +{% endhighlight %} +
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 07655baa414b..e9b2d276cd38 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -504,7 +504,6 @@ will in the future. {% highlight python %} from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -676,7 +675,6 @@ Note that the Python API does not yet support model save/load but will in the fu {% highlight python %} from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel -from numpy import array # Load and parse the data def parsePoint(line): diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index be04d0b4b53a..6acfc71d7b01 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -438,22 +438,65 @@ run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstra and interpret the hypothesis tests. {% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.stat.Statistics._ +import org.apache.spark.mllib.stat.Statistics val data: RDD[Double] = ... // an RDD of sample data // run a KS test for the sample versus a standard normal distribution val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) println(testResult) // summary of the test including the p-value, test statistic, - // and null hypothesis - // if our p-value indicates significance, we can reject the null hypothesis + // and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis // perform a KS test using a cumulative distribution function of our making val myCDF: Double => Double = ... val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) {% endhighlight %}
+ +
+[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; + +JavaSparkContext jsc = ... +JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, ...)); +KolmogorovSmirnovTestResult testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); +// summary of the test including the p-value, test statistic, +// and null hypothesis +// if our p-value indicates significance, we can reject the null hypothesis +System.out.println(testResult); +{% endhighlight %} +
+ +
+[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to +run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run +and interpret the hypothesis tests. + +{% highlight python %} +from pyspark.mllib.stat import Statistics + +parallelData = sc.parallelize([1.0, 2.0, ... ]) + +# run a KS test for the sample versus a standard normal distribution +testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) +print(testResult) # summary of the test including the p-value, test statistic, + # and null hypothesis + # if our p-value indicates significance, we can reject the null hypothesis +# Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with +# a lambda to calculate the CDF is not made available in the Python API +{% endhighlight %} +
@@ -528,5 +571,82 @@ u = RandomRDDs.uniformRDD(sc, 1000000L, 10) v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} + + +## Kernel density estimation + +[Kernel density estimation](https://en.wikipedia.org/wiki/Kernel_density_estimation) is a technique +useful for visualizing empirical probability distributions without requiring assumptions about the +particular distribution that the observed samples are drawn from. It computes an estimate of the +probability density function of a random variable, evaluated at a given set of points. It achieves +this estimate by expressing the PDF of the empirical distribution at a particular point as the +mean of PDFs of normal distributions centered around each of the samples. + +
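The following standalone sketch (plain Python; the sample values and bandwidth are made up for this illustration) spells that out: the estimated density at a point is just the average of Gaussian densities, one centered at each sample, all sharing the bandwidth as their standard deviation. The `KernelDensity` examples below compute the equivalent quantity on an RDD.

{% highlight python %}
import math

def gaussian_pdf(x, mean, sd):
    """Density of a normal distribution with the given mean and standard deviation."""
    return math.exp(-((x - mean) ** 2) / (2.0 * sd * sd)) / (sd * math.sqrt(2.0 * math.pi))

def kernel_density(samples, bandwidth, points):
    """Estimate the density at each point as the mean of Gaussians centered at the samples."""
    return [sum(gaussian_pdf(p, s, bandwidth) for s in samples) / float(len(samples))
            for p in points]

samples = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]  # hypothetical observed data
print(kernel_density(samples, bandwidth=3.0, points=[-1.0, 2.0, 5.0]))
{% endhighlight %}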
+ +
+[`KernelDensity`](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight scala %} +import org.apache.spark.mllib.stat.KernelDensity +import org.apache.spark.rdd.RDD + +val data: RDD[Double] = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +val kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0) + +// Find density estimates for the given values +val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) +{% endhighlight %} +
+ +
+ +[`KernelDensity`](api/java/org/apache/spark/mllib/stat/KernelDensity.html) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight java %} +import org.apache.spark.mllib.stat.KernelDensity; +import org.apache.spark.rdd.RDD; + +RDD<Double> data = ... // an RDD of sample data + +// Construct the density estimator with the sample data and a standard deviation for the Gaussian +// kernels +KernelDensity kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0); + +// Find density estimates for the given values +double[] densities = kd.estimate(new double[] {-1.0, 2.0, 5.0}); +{% endhighlight %} +
+ +
+[`KernelDensity`](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) provides methods +to compute kernel density estimates from an RDD of samples. The following example demonstrates how +to do so. + +{% highlight python %} +from pyspark.mllib.stat import KernelDensity + +data = ... # an RDD of sample data + +# Construct the density estimator with the sample data and a standard deviation for the Gaussian +# kernels +kd = KernelDensity() +kd.setSample(data) +kd.setBandwidth(3.0) + +# Find density estimates for the given values +densities = kd.estimate([-1.0, 2.0, 5.0]) +{% endhighlight %} +
diff --git a/docs/programming-guide.md b/docs/programming-guide.md index ae712d62746f..4cf83bb39263 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -85,8 +85,8 @@ import org.apache.spark.SparkConf
-Spark {{site.SPARK_VERSION}} works with Python 2.6 or higher (but not Python 3). It uses the standard CPython interpreter, -so C libraries like NumPy can be used. +Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter, +so C libraries like NumPy can be used. It also works with PyPy 2.3+. To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory. This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster. @@ -104,6 +104,14 @@ Finally, you need to import some Spark classes into your program. Add the follow from pyspark import SparkContext, SparkConf {% endhighlight %} +PySpark requires the same minor version of Python in both the driver and the workers. It uses the default Python version on PATH; +you can specify which version of Python to use with `PYSPARK_PYTHON`, for example: + +{% highlight bash %} +$ PYSPARK_PYTHON=python3.4 bin/pyspark +$ PYSPARK_PYTHON=/opt/pypy-2.5/bin/pypy bin/spark-submit examples/src/main/python/pi.py +{% endhighlight %} +
@@ -541,7 +549,7 @@ returning only its answer to the driver program. If we also wanted to use `lineLengths` again later, we could add: {% highlight java %} -lineLengths.persist(); +lineLengths.persist(StorageLevel.MEMORY_ONLY()); {% endhighlight %} before the `reduce`, which would cause `lineLengths` to be saved in memory after the first time it is computed. diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index debdd2adf22d..cfd219ab02e2 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -216,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler +can scale back up to the same amount of executors when Spark signals more executors are needed. + +Users that like to utilize this feature should launch the Mesos Shuffle Service that +provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts accordingly. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service with Marathon with a unique host constraint. # Configuration @@ -306,6 +320,14 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + spark.mesos.principal Framework principal to authenticate to Mesos diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d..1400ae287dcb 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -21,32 +21,51 @@ There are two deploy modes that can be used to launch Spark applications on YARN Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. To launch a Spark application in `yarn-cluster` mode: - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` - + `$ ./bin/spark-submit --class path.to.your.Class --master yarn --deploy-mode yarn-client/yarn-cluster [options] [app options]` + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ + --master yarn \ + --deploy-mode cluster --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ --queue thequeue \ lib/spark-examples*.jar \ - 10 + +`--deploy-mode` can be either client or cluster. 
-The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. +The above example starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client` in the --deploy-mode. To run spark-shell: $ ./bin/spark-shell --master yarn-client + +The alternative to launching a Spark application on YARN is to set deployment mode for the YARN master in the `--master` itself. + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +`--master` can be `yarn-client` or `yarn-cluster` ## Adding Other JARs In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ + --master yarn + --deploy-mode cluster \ --jars my-other-jar.jar,my-other-other-jar.jar my-main-jar.jar app_arg1 app_arg2 @@ -199,7 +218,7 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. @@ -319,6 +338,14 @@ If you need a reference to the proper location to put log files in the YARN so t running against earlier versions, this property will be ignored. + + spark.yarn.tags + (none) + + Comma-separated list of strings to pass through as YARN application tags appearing + in YARN ApplicationReports, which can be used for filtering when querying YARN apps. + + spark.yarn.keytab (none) @@ -361,11 +388,23 @@ If you need a reference to the proper location to put log files in the YARN so t See spark.yarn.config.gatewayPath. + + spark.yarn.security.tokens.${service}.enabled + true + + Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. 
+ By default, delegation tokens for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run. +

+ Currently supported services are: hive, hbase + + # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. +- In `--master yarn --deploy-mode cluster`, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sparkr.md b/docs/sparkr.md index 4385a4eeacd5..7139d16b4a06 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. (similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -230,3 +231,37 @@ head(teenagers) {% endhighlight %} + +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. + +

+{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6c317175d327..3af171f10b2f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -28,7 +28,7 @@ All of the examples on this page use sample data included in the Spark distribut
The entry point into all functionality in Spark SQL is the -[`SQLContext`](api/scala/index.html#org.apache.spark.sql.`SQLContext`) class, or one of its +[`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} @@ -1551,7 +1551,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a (SerDes) in order to access data stored in Hive. Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +the query on a YARN cluster (`--master yarn --deploy-mode cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879..7571e22575ef 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
// Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference(); + final AtomicReference offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function, JavaPairRDD>() { diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index dbfdb619f89e..118ced298f4b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -1702,7 +1702,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example +context and set up the DStreams. See the Java example [JavaRecoverableNetworkWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). This example appends the word counts of network data into a file. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e52..b32d9c12cd7e 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -87,6 +87,12 @@ run it with `--help`. Here are a few examples of common options: --total-executor-cores 100 \ /path/to/examples.jar \ 1000 + +# Run a Python application on a Spark Standalone cluster +./bin/spark-submit \ + --master spark://207.184.161.138:7077 \ + examples/src/main/python/pi.py \ + 1000 # Run on a Spark Standalone cluster in cluster deploy mode with supervise ./bin/spark-submit \ @@ -99,7 +105,7 @@ run it with `--help`. Here are a few examples of common options: /path/to/examples.jar \ 1000 -# Run on a YARN cluster +# Run on a YARN cluster without --deploy mode export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ @@ -108,12 +114,6 @@ export HADOOP_CONF_DIR=XXX --num-executors 50 \ /path/to/examples.jar \ 1000 - -# Run a Python application on a Spark Standalone cluster -./bin/spark-submit \ - --master spark://207.184.161.138:7077 \ - examples/src/main/python/pi.py \ - 1000 {% endhighlight %} # Master URLs @@ -140,7 +140,6 @@ cluster mode. 
The cluster location will be found based on the HADOOP_CONF_DIR or - # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca577..a377694507d2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -124,7 +124,7 @@ public String uid() { /** * Param for max number of iterations - *

+ *

* NOTE: The usual way to add a parameter to a model or algorithm is to include: * - val myParamName: ParamType * - def getMyParamName @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

+ *

* This is used for the defaul implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected @@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 02f58f48b07a..99b63a2590ae 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -45,7 +45,7 @@ * Usage: JavaStatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. - *

+ *

* To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 6ef188a220c5..ea20678b9aca 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -23,8 +23,8 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ @@ -37,7 +37,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 091b64d8c4af..d75bc6daac13 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -23,8 +23,9 @@ https://flume.apache.org/documentation.html and then run the example - `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ - spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + `$ bin/spark-submit --jars \ + external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ + examples/src/main/python/streaming/flume_wordcount.py \ localhost 12345 """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index b178e7899b5e..8d697f620f46 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -23,8 +23,9 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 000000000000..abf9c0e21d30 --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this in your local machine, you need to setup a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars \ + external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ + examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py index dcd6a0fc6ff9..b3808907f74a 100644 --- a/examples/src/main/python/streaming/queue_stream.py +++ b/examples/src/main/python/streaming/queue_stream.py @@ -36,8 +36,8 @@ # Create the queue through which RDDs can be pushed to # a QueueInputDStream rddQueue = [] - for i in xrange(5): - rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + for i in range(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)] # Create the QueueInputDStream and use it do some processing inputStream = ssc.queueStream(rddQueue) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe56..340c3559b15e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -179,7 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. 
*/ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cd411397a4b9..3ae53e57dbdb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -76,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 57ffe3dd2524..cc6bce3cb7c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -100,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index e43a6f2864c7..69691ae297f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 13189595d1d6..e05e4318969c 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -68,6 +68,11 @@ commons-codec provided + + commons-lang + commons-lang + provided + commons-net commons-net @@ -88,6 +93,12 @@ avro-ipc provided + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + org.scala-lang scala-library diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 977514fa5a1e..36342f37bb2e 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -47,6 +47,90 @@ ${project.version} provided + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + net.jpountz.lz4 + lz4 + provided 
+ + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 1a9d78c0d4f5..ea5f842c6caf 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -197,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 000000000000..f3e3f93e7ed5 --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,176 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784..69b309876a0d 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -78,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml 
b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 000000000000..ecab5b360eb3 --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba3..38a1114863d1 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869b..a6a9249db8ed 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { 
@@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. 
- */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 000000000000..1a371b700882 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + 
message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 70d2c9c58f54..3ca538608f69 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -47,6 +47,85 @@ ${project.version} provided + + + com.fasterxml.jackson.core + jackson-databind + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9a..521b53e230c4 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 2bcf8684b8b8..a3ad6bed1c99 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3b..a4e3acc674f3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -32,7 +32,7 @@ class Main { /** * Usage: Main [class] [class args] - *
+ * <p>
 * This CLI works in two different modes:
 * <ul>
 *
  • "spark-submit": if class is "org.apache.spark.deploy.SparkSubmit", the @@ -42,7 +42,7 @@ class Main { * * This class works in tandem with the "bin/spark-class" script on Unix-like systems, and * "bin/spark-class2.cmd" batch script on Windows to execute the final command. - *
+ * <p>
    * On Unix-like systems, the output is a list of command arguments, separated by the NULL * character. On Windows, the output is a command line suitable for direct execution from the * script. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 5f95e2c74f90..931a24cfd4b1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -28,7 +28,7 @@ /** * Command builder for internal Spark classes. - *
+ * <p>
    * This class handles building the command to launch all internal Spark classes except for * SparkSubmit (which is handled by {@link SparkSubmitCommandBuilder} class. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index c0f89c923069..57993405e47b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,12 +20,13 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. *
    * Use this class to start Spark applications programmatically. The class uses a builder pattern @@ -57,7 +58,8 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -187,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *
    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *
    + * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *
    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -277,4 +346,32 @@ public Process launch() throws IOException { return pb.start(); } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. + } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 87c43aa9980e..fc87814a59ed 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -25,11 +25,11 @@ /** * Special command builder for handling a CLI invocation of SparkSubmit. - *
+ * <p>
    * This builder adds command line parsing compatible with SparkSubmit. It handles setting * driver-side options and special parsing behavior needed for the special-casing certain internal * Spark applications. - *
+ * <p>
    * This class has also some special features to aid launching pyspark. */ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { @@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; + final List sparkArgs; private final boolean printHelp; /** diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 5779eb3fc0f7..6767cc507964 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -23,7 +23,7 @@ /** * Parser for spark-submit command line options. - *
+ * <p>
    * This class encapsulates the parsing code for spark-submit command line options, so that there * is a single list of options that needs to be maintained (well, sort of, but it makes it harder * to break things). @@ -80,10 +80,10 @@ class SparkSubmitOptionParser { * This is the canonical list of spark-submit options. Each entry in the array contains the * different aliases for the same option; the first element of each entry is the "official" * name of the option, passed to {@link #handle(String, String)}. - *
+ * <p>
    * Options not listed here nor in the "switch" list below will result in a call to * {@link $#handleUnknown(String)}. - *
+ * <p>
    * These two arrays are visible for tests. */ final String[][] opts = { @@ -130,7 +130,7 @@ class SparkSubmitOptionParser { /** * Parse a list of spark-submit command line options. - *
+ * <p>
    * See SparkSubmitArguments.scala for a more formal description of available options. * * @throws IllegalArgumentException If an error is found during parsing. diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca..d0c26dd05679 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,6 +20,7 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -35,8 +36,54 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dfoo=bar -Dtest.name=-testChildProcLauncher") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); diff --git a/make-distribution.sh b/make-distribution.sh index 4789b0e09cc8..04ad0052eb24 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.7.0" +TACHYON_VERSION="0.7.1" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" 
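The `addSparkArg` additions above let callers pass raw spark-submit flags through `SparkLauncher`, validating known options and passing unknown ones through for forward compatibility. A hedged usage sketch, mirroring what `testSparkArgumentHandling` exercises; the jar path, main class, and flag values are illustrative, not taken from this patch:

```scala
import org.apache.spark.launcher.SparkLauncher

object LauncherSketch {
  def main(args: Array[String]): Unit = {
    val app = new SparkLauncher()
      .setAppResource("/path/to/app.jar")    // illustrative application jar
      .setMainClass("com.example.MyApp")     // illustrative main class
      .setMaster("local[2]")
      .addSparkArg("--verbose")              // known no-value spark-submit flag
      .addSparkArg("--driver-memory", "2g")  // known flag that takes a value
      .launch()
    app.waitFor()
  }
}
```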
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" @@ -219,7 +219,6 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d287..a3e59401c5cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -198,6 +198,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 29598f3f05c2..6f70b96b17ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] ( override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a959926..3073a2a61ce8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -196,7 +196,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f55134d25885..21fbe38ca823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization { + with HasStandardization with HasThreshold { /** - * Version of setThresholds() for binary classification, available for backwards - * compatibility. 
+ * Set threshold in binary classification, in range [0, 1]. * - * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`. + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. * - * Default is effectively 0.5. * @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value)) + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } /** - * Version of [[getThresholds()]] for binary classification, available for backwards - * compatibility. + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. * - * Param thresholds must have length 2 (or not be specified). - * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}. * @group getParam */ - def getThreshold: Double = { - if (isDefined(thresholds)) { - val thresholdValues = $(thresholds) - assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" + - " binary classification, but thresholds has length != 2." 
+ - s" thresholds: ${thresholdValues.mkString(",")}") - 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1)) + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) } else { - 0.5 + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." + + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") } } + + override def validateParams(): Unit = { + checkThresholdConsistency() + } } /** * :: Experimental :: * Logistic regression. - * Currently, this class only supports binary classification. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ @Experimental class LogisticRegression(override val uid: String) @@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String) * Whether to fit an intercept term. * Default is true. * @group setParam - * */ + */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String) * is applied. In R's GLMNET package, the default behavior is true as well. * Default is true. * @group setParam - * */ + */ def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) @@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractLabeledPoints(dataset).map { @@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] ( override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept @@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[thresholds]]. */ override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. 
if (score(features) > getThreshold) 1 else 0 } @@ -389,10 +468,11 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. val t = getThreshold val rawThreshold = if (t == 0.0) { Double.NegativeInfinity @@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] ( } override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8cd2103d7d5e..1e5b0bc4453e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -42,9 +42,6 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams ParamValidators.arrayLengthGt(1) ) - /** @group setParam */ - def setLayers(value: Array[Int]): this.type = set(layers, value) - /** @group getParam */ final def getLayers: Array[Int] = $(layers) @@ -61,33 +58,9 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams "it is adjusted to the size of this data. Recommended size is between 10 and 1000", ParamValidators.gt(0)) - /** @group setParam */ - def setBlockSize(value: Int): this.type = set(blockSize, value) - /** @group getParam */ final def getBlockSize: Int = $(blockSize) - /** - * Set the maximum number of iterations. - * Default is 100. - * @group setParam - */ - def setMaxIter(value: Int): this.type = set(maxIter, value) - - /** - * Set the convergence tolerance of iterations. - * Smaller value will lead to higher accuracy with the cost of more iterations. - * Default is 1E-4. - * @group setParam - */ - def setTol(value: Double): this.type = set(tol, value) - - /** - * Set the seed for weights initialization. - * @group setParam - */ - def setSeed(value: Long): this.type = set(seed, value) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) } @@ -131,11 +104,38 @@ private object LabelConverter { */ @Experimental class MultilayerPerceptronClassifier(override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { def this() = this(Identifiable.randomUID("mlpc")) + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. 
+ * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) @@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) } } /** * :: Experimental :: - * Classifier model based on the Multilayer Perceptron. + * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. * @param uid uid * @param layers array of layer sizes including input and output layers @@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String) * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml] ( +class MultilayerPerceptronClassificationModel private[ml] ( override val uid: String, - layers: Array[Int], - weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + val layers: Array[Int], + val weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } - override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { - copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc911..c62e132f5d53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,14 +131,14 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) .drop(accColName) } override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } @@ -203,8 +203,8 @@ final class OneVsRest(override val uid: 
String) // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) - val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) + val trainingDataset = + multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1e50a895a9a0..fdd1851ae550 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -50,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassifier[ +abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -74,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[ * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassificationModel[ +abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 156050aaf7a4..11a6d7246833 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] ( override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc192add6ca1..47a18cdb31b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -27,14 +27,13 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import 
org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils /** * Common params for KMeans and KMeansModel */ -private[clustering] trait KMeansParams - extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. @@ -45,31 +44,6 @@ private[clustering] trait KMeansParams /** @group getParam */ def getK: Int = $(k) - /** - * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Must be >= 1. Default: 1. - * @group param - */ - final val runs = new IntParam(this, "runs", - "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) - - /** @group getParam */ - def getRuns: Int = $(runs) - - /** - * Param the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - * Must be >= 0.0. Default: 1e-4 - * @group param - */ - final val epsilon = new DoubleParam(this, "epsilon", - "distance threshold within which we've consider centers to have converge", - (value: Double) => value >= 0.0) - - /** @group getParam */ - def getEpsilon: Double = $(epsilon) - /** * Param for the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ @@ -136,9 +110,9 @@ class KMeansModel private[ml] ( /** * :: Experimental :: - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. 
+ * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ @Experimental class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { @@ -146,10 +120,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean setDefault( k -> 2, maxIter -> 20, - runs -> 1, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 5, - epsilon -> 1e-4) + tol -> 1e-4) override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -174,10 +147,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setRuns(value: Int): this.type = set(runs, value) - - /** @group setParam */ - def setEpsilon(value: Double): this.type = set(epsilon, value) + def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) @@ -191,8 +161,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean .setInitializationSteps($(initSteps)) .setMaxIterations($(maxIter)) .setSeed($(seed)) - .setEpsilon($(epsilon)) - .setRuns($(runs)) + .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 5d5cb7e94f45..56419a0a1595 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -40,8 +40,11 @@ class BinaryClassificationEvaluator(override val uid: String) * param for metric name in evaluation * @group param */ - val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR")) + new Param( + this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams) + } /** @group getParam */ def getMetricName: String = $(metricName) @@ -76,16 +79,17 @@ class BinaryClassificationEvaluator(override val uid: String) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { - case "areaUnderROC" => - metrics.areaUnderROC() - case "areaUnderPR" => - metrics.areaUnderPR() - case other => - throw new IllegalArgumentException(s"Does not support metric $other.") + case "areaUnderROC" => metrics.areaUnderROC() + case "areaUnderPR" => metrics.areaUnderPR() } metrics.unpersist() metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "areaUnderROC" => true + case "areaUnderPR" => true + } + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index e56c946a063e..13bd3307f8a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -46,5 +46,12 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double + /** + * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, 
default) + * or minimized (false). + * A given evaluator may support multiple metrics which may be maximized or minimized. + */ + def isLargerBetter: Boolean = true + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 44f779c1908d..f73d2345078e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -81,5 +81,13 @@ class MulticlassClassificationEvaluator (override val uid: String) metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "f1" => true + case "precision" => true + case "recall" => true + case "weightedPrecision" => true + case "weightedRecall" => true + } + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 01c000b47514..d21c88ab9b10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -73,17 +73,20 @@ final class RegressionEvaluator(override val uid: String) } val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { - case "rmse" => - -metrics.rootMeanSquaredError - case "mse" => - -metrics.meanSquaredError - case "r2" => - metrics.r2 - case "mae" => - -metrics.meanAbsoluteError + case "rmse" => metrics.rootMeanSquaredError + case "mse" => metrics.meanSquaredError + case "r2" => metrics.r2 + case "mae" => metrics.meanAbsoluteError } metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "rmse" => false + case "mse" => false + case "r2" => true + case "mae" => false + } + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc355..6fdf25b015b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String) } val newCol = bucketizer(dataset($(inputCol))) val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala new file mode 100644 index 000000000000..49028e4b8506 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.collection.OpenHashMap + +/** + * Params for [[CountVectorizer]] and [[CountVectorizerModel]]. + */ +private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol { + + /** + * Max size of the vocabulary. + * CountVectorizer will build a vocabulary that only considers the top + * vocabSize terms ordered by term frequency across the corpus. + * + * Default: 2^18^ + * @group param + */ + val vocabSize: IntParam = + new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0)) + + /** @group getParam */ + def getVocabSize: Int = $(vocabSize) + + /** + * Specifies the minimum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer >= 1, this specifies the number of documents the term must appear in; + * if this is a double in [0,1), then this specifies the fraction of documents. + * + * Default: 1 + * @group param + */ + val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMinDF: Double = $(minDF) + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + /** + * Filter to ignore rare words in a document. For each document, terms with + * frequency/count less than the given threshold are ignored. + * If this is an integer >= 1, then this specifies a count (of times the term must appear + * in the document); + * if this is a double in [0,1), then this specifies a fraction (out of the document's token + * count). + * + * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not + * affect fitting. 
+ * + * Default: 1 + * @group param + */ + val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given threshold are" + + " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" + + " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" + + " of the document's token count). Note that the parameter is only used in transform of" + + " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0)) + + setDefault(minTF -> 1) + + /** @group getParam */ + def getMinTF: Double = $(minTF) +} + +/** + * :: Experimental :: + * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. + */ +@Experimental +class CountVectorizer(override val uid: String) + extends Estimator[CountVectorizerModel] with CountVectorizerParams { + + def this() = this(Identifiable.randomUID("cntVec")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVocabSize(value: Int): this.type = set(vocabSize, value) + + /** @group setParam */ + def setMinDF(value: Double): this.type = set(minDF, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + setDefault(vocabSize -> (1 << 18), minDF -> 1) + + override def fit(dataset: DataFrame): CountVectorizerModel = { + transformSchema(dataset.schema, logging = true) + val vocSize = $(vocabSize) + val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val minDf = if ($(minDF) >= 1.0) { + $(minDF) + } else { + $(minDF) * input.cache().count() + } + val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val wc = new OpenHashMap[String, Long] + tokens.foreach { w => + wc.changeValue(w, 1L, _ + 1L) + } + wc.map { case (word, count) => (word, (count, 1)) } + }.reduceByKey { case ((wc1, df1), (wc2, df2)) => + (wc1 + wc2, df1 + df2) + }.filter { case (word, (wc, df)) => + df >= minDf + }.map { case (word, (count, dfCount)) => + (word, count) + }.cache() + val fullVocabSize = wordCounts.count() + val vocab: Array[String] = { + val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocSize) + } + tmpSortedWC.map(_._1) + } + + require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") + copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Converts a text document to a sparse vector of token counts. + * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. 
+ */ +@Experimental +class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) + extends Model[CountVectorizerModel] with CountVectorizerParams { + + def this(vocabulary: Array[String]) = { + this(Identifiable.randomUID("cntVecModel"), vocabulary) + set(vocabSize, vocabulary.length) + } + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMinTF(value: Double): this.type = set(minTF, value) + + /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ + private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None + + override def transform(dataset: DataFrame): DataFrame = { + if (broadcastDict.isEmpty) { + val dict = vocabulary.zipWithIndex.toMap + broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) + } + val dictBr = broadcastDict.get + val minTf = $(minTF) + val vectorizer = udf { (document: Seq[String]) => + val termCounts = new OpenHashMap[Int, Double] + var tokenCount = 0L + document.foreach { term => + dictBr.value.get(term) match { + case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0) + case None => // ignore terms not in the vocabulary + } + tokenCount += 1 + } + val effectiveMinTF = if (minTf >= 1.0) { + minTf + } else { + tokenCount * minTf + } + Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq) + } + dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): CountVectorizerModel = { + val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala deleted file mode 100644 index 6b77de89a033..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
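For reference, a minimal usage sketch of the new `CountVectorizer` / `CountVectorizerModel` pair introduced above (the toy corpus and the `sqlContext` handle are illustrative assumptions, not part of this patch):

```scala
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// toy corpus: each row holds a pre-tokenized document (ArrayType(StringType))
val df = sqlContext.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "words")

// fit a vocabulary of at most 3 terms, keeping terms that appear in >= 2 documents
val cvModel: CountVectorizerModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)
  .setMinDF(2.0)
  .fit(df)

cvModel.transform(df).select("words", "features").show()

// alternatively, wrap an a-priori vocabulary directly in a model
val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
```

Unlike the removed `UnaryTransformer`-based model, `minTF` accepts either an absolute count (>= 1) or a fraction of each document's token count, and it only affects `transform`, not fitting.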
- */ -package org.apache.spark.ml.feature - -import scala.collection.mutable - -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector} -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} - -/** - * :: Experimental :: - * Converts a text document to a sparse vector of token counts. - * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. - */ -@Experimental -class CountVectorizerModel (override val uid: String, val vocabulary: Array[String]) - extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] { - - def this(vocabulary: Array[String]) = - this(Identifiable.randomUID("cntVec"), vocabulary) - - /** - * Corpus-specific filter to ignore scarce words in a document. For each document, terms with - * frequency (count) less than the given threshold are ignored. - * Default: 1 - * @group param - */ - val minTermFreq: IntParam = new IntParam(this, "minTermFreq", - "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " + - "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1)) - - /** @group setParam */ - def setMinTermFreq(value: Int): this.type = set(minTermFreq, value) - - /** @group getParam */ - def getMinTermFreq: Int = $(minTermFreq) - - setDefault(minTermFreq -> 1) - - override protected def createTransformFunc: Seq[String] => Vector = { - val dict = vocabulary.zipWithIndex.toMap - document => - val termCounts = mutable.HashMap.empty[Int, Double] - document.foreach { term => - dict.get(term) match { - case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0) - case None => // ignore terms not in the vocabulary - } - } - Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq) - } - - override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") - } - - override protected def outputDataType: DataType = new VectorUDT() - - override def copy(extra: ParamMap): CountVectorizerModel = { - val copied = new CountVectorizerModel(uid, vocabulary) - copyValues(copied, extra) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde80810580..938447447a0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -114,6 +114,6 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b30adf3df48d..1b494ec8b172 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -41,6 +41,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output feature range") + /** @group getParam */ + def getMin: Double = $(min) + /** * 
upper bound after transformation, shared by all features * Default: 1.0 @@ -49,6 +52,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output feature range") + /** @group getParam */ + def getMax: Double = $(max) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType @@ -115,6 +121,9 @@ class MinMaxScaler(override val uid: String) * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). */ @Experimental @@ -136,7 +145,6 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray @@ -165,6 +173,6 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf30..539084704b65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -125,6 +125,6 @@ class PCAModel private[ml] ( override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d5360c9217ea..a7fa50444209 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -33,11 +33,6 @@ import org.apache.spark.sql.types._ * Base trait for [[RFormula]] and [[RFormulaModel]]. */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) @@ -47,8 +42,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula - * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '.', '~', '+', and '-'. 
Also see the + * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { @@ -71,6 +66,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 000000000000..95e430563873 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. 
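As a quick illustration of the `'SELECT ... FROM __THIS__'` syntax described above (the column names and data are hypothetical):

```scala
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.createDataFrame(
  Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")

// '__THIS__' is substituted with a temp table registered over the input DataFrame
val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

sqlTrans.transform(df).show()
```

Note that the output schema is derived by running the statement against a dummy DataFrame built from the input schema (see `transformSchema` below), so the statement must be valid Spark SQL independent of the data.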
+ * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db3e..f6d0b0c0e9e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -136,6 +136,6 @@ class StandardScalerModel private[ml] ( override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3cc41424460f..5d77ea08db65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * stop words list @@ -100,7 +100,7 @@ class StopWordsRemover(override val uid: String) * the stop words set to be filtered out * @group param */ - val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") /** @group setParam */ def setStopWords(value: Array[String]): this.type = set(stopWords, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ebfa97253235..24250e4c4cf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import 
org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -58,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -65,13 +68,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -91,14 +97,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * * NOTE: During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. 
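A short sketch of the new `handleInvalid` behaviour (data, column names, and the `sqlContext` handle are assumptions for illustration):

```scala
import org.apache.spark.ml.feature.StringIndexer

val train = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))).toDF("id", "category")
val test = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "d"))).toDF("id", "category")  // "d" was never seen in training

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")  // default "error" throws SparkException on unseen labels

// with "skip", the row containing "d" is filtered out instead of failing the job
indexer.fit(train).transform(test).show()
```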
+ * + * @param labels Ordered list of labels, corresponding to indices to be assigned */ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, - labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length @@ -111,6 +122,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -128,14 +143,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } @@ -150,36 +175,26 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) - } - - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. - */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) + copyValues(copied, extra).setParent(parent) } } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. 
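And a sketch of the inverse mapping with the renamed `IndexToString` transformer, relying on the `NominalAttribute` metadata that `StringIndexerModel` writes to its output column (toy data assumed as before):

```scala
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))).toDF("id", "category")

val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

// categoryIndex carries the fitted labels as ML attribute metadata, so no explicit
// label list is needed to map the indices back to strings
val converted = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")
  .transform(indexed)

converted.select("id", "categoryIndex", "originalCategory").show()
```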
+ * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -239,7 +254,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -250,7 +265,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef5f..61b925c0fdc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } override def transformSchema(schema: StructType): StructType = { @@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 772bebeff214..c5c227227079 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String) case features: SparseVector => features.slice(inds) } } - dataset.withColumn($(outputCol), - slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) } /** Get the feature indices in order: indices, names */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 29acc3eb5865..5af775a4159a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,6 +221,6 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala new file mode 100644 index 000000000000..4571ab26800c --- /dev/null +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, VectorAssembler} +import org.apache.spark.sql.DataFrame + +/** + * == Feature transformers == + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as [[Transformer]]s, which transform one [[DataFrame]] + * into another, e.g., [[HashingTF]]. + * Some feature transformers are implemented as [[Estimator]]s, because the transformation requires + * some aggregated information of the dataset, e.g., document frequencies in [[IDF]]. + * For those feature transformers, calling [[Estimator!.fit]] is required to obtain the model first, + * e.g., [[IDFModel]], in order to apply transformation. + * The transformation is usually done by appending new columns to the input [[DataFrame]], so all + * input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * [[Pipeline]] can be used to chain feature transformers, and [[VectorAssembler]] can be used to + * combine multiple feature transformations, for example: + * + * {{{ + * import org.apache.spark.ml.feature._ + * import org.apache.spark.ml.Pipeline + * + * // a DataFrame with three columns: id (integer), text (string), and rating (double). + * val df = sqlContext.createDataFrame(Seq( + * (0, "Hi I heard about Spark", 3.0), + * (1, "I wish Java could use case classes", 4.0), + * (2, "Logistic regression models are neat", 4.0) + * )).toDF("id", "text", "rating") + * + * // define feature transformers + * val tok = new RegexTokenizer() + * .setInputCol("text") + * .setOutputCol("words") + * val sw = new StopWordsRemover() + * .setInputCol("words") + * .setOutputCol("filtered_words") + * val tf = new HashingTF() + * .setInputCol("filtered_words") + * .setOutputCol("tf") + * .setNumFeatures(10000) + * val idf = new IDF() + * .setInputCol("tf") + * .setOutputCol("tf_idf") + * val assembler = new VectorAssembler() + * .setInputCols(Array("tf_idf", "rating")) + * .setOutputCol("features") + * + * // assemble and fit the feature transformation pipeline + * val pipeline = new Pipeline() + * .setStages(Array(tok, sw, tf, idf, assembler)) + * val model = pipeline.fit(df) + * + * // save transformed features with raw data + * model.transform(df) + * .select("id", "text", "rating", "features") + * .write.format("parquet").save("/output/path") + * }}} + * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. 
+ * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see [[http://scikit-learn.org/stable/modules/preprocessing.html scikit-learn.preprocessing]] + */ +package object feature diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c..91c0a5631319 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. - * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a97c8059b8d4..8c16c6149b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,20 +45,24 @@ private[shared] object SharedParamsCodeGen { " These probabilities should be treated as confidences, not precise probabilities.", Some("\"probability\"")), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values >= 0." 
+ " The class with largest value p/t is predicted, where p is the original probability" + " of that class and t is the class' threshold.", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)"), + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), @@ -66,7 +70,9 @@ private[shared] object SharedParamsCodeGen { " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f332630c32f1..c26768953e3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { @@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params { */ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + setDefault(threshold, 0.5) + /** @group getParam */ def getThreshold: Double = $(threshold) } @@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params { final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) /** @group getParam */ - final def getThresholds: Array[Double] = $(thresholds) + def getThresholds: Array[Double] = $(thresholds) } /** @@ -247,6 +249,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. + */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ @@ -325,4 +342,19 @@ private[ml] trait HasStepSize extends Params { /** @group getParam */ final def getStepSize: Double = $(stepSize) } + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2e44cd4cc6a2..7db8ad8d2791 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -219,7 +219,7 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index dc94a1401454..a2bcd67401d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc320273..b66e61f37dd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -185,7 +185,7 @@ final class GBTRegressionModel( } override def copy(extra: 
ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index f570590960a6..0f33bae30e62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} @@ -35,19 +35,7 @@ import org.apache.spark.storage.StorageLevel * Params for isotonic regression. */ private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol - with HasLabelCol with HasPredictionCol with Logging { - - /** - * Param for weight column name (default: none). - * @group param - */ - // TODO: Move weightCol to sharedParams. - final val weightCol: Param[String] = - new Param[String](this, "weightCol", - "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") - - /** @group getParam */ - final def getWeightCol: String = $(weightCol) + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { /** * Param for whether the output sequence should be isotonic/increasing (true) or diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92d819bad865..884003eb3852 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] ( override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index db75c0d26392..2f36da371f57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4b5..0679bfd0f3ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -100,7 +100,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] @@ -160,6 +162,6 @@ class CrossValidatorModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c0edc730b6fd..73a14b831015 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -99,7 +99,9 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index ddd34a54503a..bd213e7362e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,11 +19,19 @@ package org.apache.spark.ml.util import java.util.UUID +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: + * * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. */ -private[spark] trait Identifiable { +@DeveloperApi +trait Identifiable { /** * An immutable unique ID for the object and its derivatives. @@ -33,7 +41,11 @@ private[spark] trait Identifiable { override def toString: String = uid } -private[spark] object Identifiable { +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. 
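With the tuning changes above, smaller-is-better metrics no longer need the sign-flipping that was removed from `RegressionEvaluator`: `CrossValidator` and `TrainValidationSplit` now ask the evaluator via `isLargerBetter` whether to maximize or minimize. A sketch of cross-validation over an assumed `training` DataFrame with "features" and "label" columns:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LinearRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// RegressionEvaluator now returns the raw RMSE and reports isLargerBetter = false,
// so CrossValidator selects the parameter set with the smallest average metric
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator().setMetricName("rmse"))
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

val cvModel = cv.fit(training)
```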
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index ba73024e3c04..a29b425a71fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.classification import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -36,8 +36,8 @@ trait ClassificationModel extends Serializable { * * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -45,16 +45,16 @@ trait ClassificationModel extends Serializable { * * @param testData array representing a single data point * @return predicted category from the trained model - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 268642ac6a2f..e03e662227d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} @@ -85,8 +85,8 @@ class LogisticRegressionModel ( * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. - * @since 1.0.0 */ + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -97,8 +97,8 @@ class LogisticRegressionModel ( * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. * It is only used for binary classification. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def getThreshold: Option[Double] = threshold @@ -106,8 +106,8 @@ class LogisticRegressionModel ( * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. * It is only used for binary classification. 
- * @since 1.0.0 */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None @@ -158,9 +158,7 @@ class LogisticRegressionModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures, numClasses, weights, intercept, threshold) @@ -168,9 +166,7 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } @@ -178,9 +174,7 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Loader[LogisticRegressionModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -261,8 +255,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -284,8 +278,8 @@ object LogisticRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -306,8 +300,8 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -324,8 +318,8 @@ object LogisticRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int): LogisticRegressionModel = { @@ -361,8 +355,8 @@ class LogisticRegressionWithLBFGS * Set the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. 
- * @since 1.3.0 */ + @Since("1.3.0") @Experimental def setNumClasses(numClasses: Int): this.type = { require(numClasses > 1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 2df91c09421e..dab369207cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} @@ -444,8 +445,8 @@ object NaiveBayes { * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint]): NaiveBayesModel = { new NaiveBayes().run(input) } @@ -460,8 +461,8 @@ object NaiveBayes { * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { new NaiveBayes(lambda, Multinomial).run(input) } @@ -483,8 +484,8 @@ object NaiveBayes { * * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 5b54feeb1046..5f8726986357 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ @@ -46,8 +46,8 @@ class SVMModel ( * Sets the threshold that separates positive predictions from negative predictions. An example * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -57,16 +57,16 @@ class SVMModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def getThreshold: Option[Double] = threshold /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. 
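The three `NaiveBayes.train` overloads annotated above progressively expose the smoothing parameter and the model type. A hedged sketch using the fullest overload (the tiny count-vector corpus is made up for illustration):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object NaiveBayesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("nb-example").setMaster("local[2]"))

    // Count vectors, as the scaladoc above requires for the multinomial model.
    val docs = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(3.0, 0.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 4.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 2.0))))

    // lambda is additive smoothing; modelType may be "multinomial" or "bernoulli".
    val model = NaiveBayes.train(docs, lambda = 1.0, modelType = "multinomial")

    println(model.predict(Vectors.dense(1.0, 0.0, 1.0)))   // expected label: 0.0
    sc.stop()
  }
}
```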
- * @since 1.0.0 */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None @@ -84,9 +84,7 @@ class SVMModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) @@ -94,9 +92,7 @@ class SVMModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } @@ -104,9 +100,7 @@ class SVMModel ( object SVMModel extends Loader[SVMModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): SVMModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -185,8 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -209,8 +203,8 @@ object SVMWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -231,8 +225,8 @@ object SVMWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -250,8 +244,8 @@ object SVMWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. - * @since 0.8.0 */ + @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 0.01, 1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index e459367333d2..fcc9dfecac54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -63,6 +63,7 @@ class GaussianMixture private ( * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. 
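As with logistic regression, the `SVMWithSGD.train` overloads annotated above only vary the optimizer knobs, and the returned `SVMModel` can be switched between raw margins and 0/1 predictions via the threshold methods. A hedged sketch (the data and the save path are hypothetical):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object SVMWithSGDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("svm-example").setMaster("local[2]"))

    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(2.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.5, 1.8)),
      LabeledPoint(0.0, Vectors.dense(-2.0, -1.0)),
      LabeledPoint(0.0, Vectors.dense(-1.2, -1.7))))

    // numIterations = 100, stepSize = 1.0, regParam = 0.01, miniBatchFraction = 1.0
    val model = SVMWithSGD.train(training, 100, 1.0, 0.01, 1.0)

    model.clearThreshold()                        // raw margins
    println(model.predict(Vectors.dense(1.0, 1.0)))
    model.setThreshold(0.0)                       // back to 0/1 predictions

    // Round-trip through the Saveable/Loader pair annotated above (hypothetical path).
    model.save(sc, "/tmp/svm-model")
    val reloaded = SVMModel.load(sc, "/tmp/svm-model")
    println(reloaded.predict(Vectors.dense(1.0, 1.0)))

    sc.stop()
  }
}
```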
*/ + @Since("1.3.0") def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -72,10 +73,12 @@ class GaussianMixture private ( // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - /** Set the initial GMM starting point, bypassing the random initialization. - * You must call setK() prior to calling this method, and the condition - * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + /** + * Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException */ + @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { if (model.k == k) { initialModel = Some(model) @@ -85,31 +88,47 @@ class GaussianMixture private ( this } - /** Return the user supplied initial GMM, if supplied */ + /** + * Return the user supplied initial GMM, if supplied + */ + @Since("1.3.0") def getInitialModel: Option[GaussianMixtureModel] = initialModel - /** Set the number of Gaussians in the mixture model. Default: 2 */ + /** + * Set the number of Gaussians in the mixture model. Default: 2 + */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this } - /** Return the number of Gaussians in the mixture model */ + /** + * Return the number of Gaussians in the mixture model + */ + @Since("1.3.0") def getK: Int = k - /** Set the maximum number of iterations to run. Default: 100 */ + /** + * Set the maximum number of iterations to run. Default: 100 + */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Return the maximum number of iterations to run */ + /** + * Return the maximum number of iterations to run + */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ + @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this @@ -119,18 +138,28 @@ class GaussianMixture private ( * Return the largest change in log-likelihood at which convergence is * considered to have occurred. 
*/ + @Since("1.3.0") def getConvergenceTol: Double = convergenceTol - /** Set the random seed */ + /** + * Set the random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this } - /** Return the random seed */ + /** + * Return the random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Perform expectation maximization */ + /** + * Perform expectation maximization + */ + @Since("1.3.0") def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -204,7 +233,10 @@ class GaussianMixture private ( new GaussianMixtureModel(weights, gaussians) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) private def updateWeightsAndGaussians( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 76aeebd703d4..1a10a8b62421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -44,6 +44,7 @@ import org.apache.spark.sql.{SQLContext, Row} * @param gaussians Array of MultivariateGaussian where gaussians(i) represents * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ +@Since("1.3.0") @Experimental class GaussianMixtureModel( val weights: Array[Double], @@ -53,26 +54,39 @@ class GaussianMixtureModel( override protected def formatVersion = "1.0" + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) } - /** Number of gaussians in mixture */ + /** + * Number of gaussians in mixture + */ + @Since("1.3.0") def k: Int = weights.length - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.3.0") def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - /** Maps given point to its cluster index. */ + /** + * Maps given point to its cluster index. + */ + @Since("1.5.0") def predict(point: Vector): Int = { val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) r.indexOf(r.max) } - /** Java-friendly version of [[predict()]] */ + /** + * Java-friendly version of [[predict()]] + */ + @Since("1.4.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -80,6 +94,7 @@ class GaussianMixtureModel( * Given the input vectors, return the membership value of each vector * to all mixture components. */ + @Since("1.3.0") def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) @@ -92,6 +107,7 @@ class GaussianMixtureModel( /** * Given the input vector, return the membership values to all mixture components. 
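Put together, the builder-style setters on `GaussianMixture` and the `predict`/`predictSoft` methods on `GaussianMixtureModel` annotated above are used roughly as follows; a hedged sketch with made-up two-dimensional data:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

object GaussianMixtureExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("gmm-example").setMaster("local[2]"))

    // Two loose blobs around (0, 0) and (5, 5); values are illustrative.
    val data = sc.parallelize(Seq(
      Vectors.dense(0.1, -0.2), Vectors.dense(-0.3, 0.2), Vectors.dense(0.0, 0.1),
      Vectors.dense(5.2, 4.9), Vectors.dense(4.8, 5.1), Vectors.dense(5.0, 5.3)))

    val gmm = new GaussianMixture()
      .setK(2)
      .setConvergenceTol(0.01)
      .setMaxIterations(100)
      .setSeed(11L)
      .run(data)

    gmm.gaussians.zip(gmm.weights).foreach { case (g, w) =>
      println(s"weight=$w mu=${g.mu} sigma=\n${g.sigma}")
    }
    // Soft membership for a single point (one of the 1.4/1.5 additions annotated above).
    println(gmm.predictSoft(Vectors.dense(4.9, 5.0)).mkString(", "))

    sc.stop()
  }
}
```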
*/ + @Since("1.4.0") def predictSoft(point: Vector): Array[Double] = { computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) } @@ -115,6 +131,7 @@ class GaussianMixtureModel( } } +@Since("1.4.0") @Experimental object GaussianMixtureModel extends Loader[GaussianMixtureModel] { @@ -165,6 +182,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } } + @Since("1.4.0") override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0a65403f4ec9..3e9545a74bef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -50,14 +50,19 @@ class KMeans private ( * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. */ + @Since("0.8.0") def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). */ + @Since("1.4.0") def getK: Int = k - /** Set the number of clusters to create (k). Default: 2. */ + /** + * Set the number of clusters to create (k). Default: 2. + */ + @Since("0.8.0") def setK(k: Int): this.type = { this.k = k this @@ -66,9 +71,13 @@ class KMeans private ( /** * Maximum number of iterations to run. */ + @Since("1.4.0") def getMaxIterations: Int = maxIterations - /** Set maximum number of iterations to run. Default: 20. */ + /** + * Set maximum number of iterations to run. Default: 20. + */ + @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -77,6 +86,7 @@ class KMeans private ( /** * The initialization algorithm. This can be either "random" or "k-means||". */ + @Since("1.4.0") def getInitializationMode: String = initializationMode /** @@ -84,6 +94,7 @@ class KMeans private ( * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ + @Since("0.8.0") def setInitializationMode(initializationMode: String): this.type = { KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode @@ -94,6 +105,7 @@ class KMeans private ( * :: Experimental :: * Number of runs of the algorithm to execute in parallel. */ + @Since("1.4.0") @Experimental def getRuns: Int = runs @@ -103,6 +115,7 @@ class KMeans private ( * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. 
*/ + @Since("0.8.0") @Experimental def setRuns(runs: Int): this.type = { if (runs <= 0) { @@ -115,12 +128,14 @@ class KMeans private ( /** * Number of steps for the k-means|| initialization mode */ + @Since("1.4.0") def getInitializationSteps: Int = initializationSteps /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. */ + @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") @@ -132,12 +147,14 @@ class KMeans private ( /** * The distance threshold within which we've consider centers to have converged. */ + @Since("1.4.0") def getEpsilon: Double = epsilon /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. */ + @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this @@ -146,9 +163,13 @@ class KMeans private ( /** * The random seed for cluster initialization. */ + @Since("1.4.0") def getSeed: Long = seed - /** Set the random seed for cluster initialization. */ + /** + * Set the random seed for cluster initialization. + */ + @Since("1.4.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -163,6 +184,7 @@ class KMeans private ( * The condition model.k == this.k must be met, failure results * in an IllegalArgumentException. */ + @Since("1.4.0") def setInitialModel(model: KMeansModel): this.type = { require(model.k == k, "mismatched cluster count") initialModel = Some(model) @@ -173,6 +195,7 @@ class KMeans private ( * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ + @Since("0.8.0") def run(data: RDD[Vector]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { @@ -431,10 +454,13 @@ class KMeans private ( /** * Top-level methods for calling K-means clustering. */ +@Since("0.8.0") object KMeans { // Initialization mode names + @Since("0.8.0") val RANDOM = "random" + @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" /** @@ -447,6 +473,7 @@ object KMeans { * @param initializationMode initialization model, either "random" or "k-means||" (default). * @param seed random seed value for cluster initialization */ + @Since("1.3.0") def train( data: RDD[Vector], k: Int, @@ -471,6 +498,7 @@ object KMeans { * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -487,6 +515,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -497,6 +526,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. 
*/ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 96359024fa22..e425ecdd481c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable @@ -35,28 +36,44 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ +@Since("0.8.0") class KMeansModel ( val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { - /** A Java-friendly constructor that takes an Iterable of Vectors. */ + /** + * A Java-friendly constructor that takes an Iterable of Vectors. + */ + @Since("1.4.0") def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) - /** Total number of clusters. */ + /** + * Total number of clusters. + */ + @Since("0.8.0") def k: Int = clusterCenters.length - /** Returns the cluster index that a given point belongs to. */ + /** + * Returns the cluster index that a given point belongs to. + */ + @Since("0.8.0") def predict(point: Vector): Int = { KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = points.context.broadcast(centersWithNorm) points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + */ + @Since("1.0.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -64,6 +81,7 @@ class KMeansModel ( * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. 
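The k-means setters and the `KMeansModel` methods annotated in the hunks above compose into the usual train/predict/cost loop. A hedged sketch on a made-up two-cluster dataset (the input is cached, as the `run` scaladoc above recommends):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object KMeansExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("kmeans-example").setMaster("local[2]"))

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.2, 0.1), Vectors.dense(0.1, 0.3),
      Vectors.dense(9.0, 8.8), Vectors.dense(9.2, 9.1), Vectors.dense(8.9, 9.3))).cache()

    // Builder-style configuration using the setters annotated above.
    val model = new KMeans()
      .setK(2)
      .setMaxIterations(20)
      .setInitializationMode(KMeans.K_MEANS_PARALLEL)
      .setSeed(1L)
      .run(data)

    println(s"centers: ${model.clusterCenters.mkString(", ")}")
    println(s"cluster of (9, 9): ${model.predict(Vectors.dense(9.0, 9.0))}")
    // Sum of squared distances of points to their nearest center, as documented above.
    println(s"cost: ${model.computeCost(data)}")

    sc.stop()
  }
}
```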
*/ + @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = data.context.broadcast(centersWithNorm) @@ -73,6 +91,7 @@ class KMeansModel ( private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { KMeansModel.SaveLoadV1_0.save(sc, this, path) } @@ -80,7 +99,10 @@ class KMeansModel ( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object KMeansModel extends Loader[KMeansModel] { + + @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { KMeansModel.SaveLoadV1_0.load(sc, path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ab124e6d77c5..92a321afb0ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -44,6 +44,7 @@ import org.apache.spark.util.Utils * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] */ +@Since("1.3.0") @Experimental class LDA private ( private var k: Int, @@ -54,19 +55,26 @@ class LDA private ( private var checkpointInterval: Int, private var ldaOptimizer: LDAOptimizer) extends Logging { + /** + * Constructs a LDA instance with default parameters. + */ + @Since("1.3.0") def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. + * */ + @Since("1.3.0") def getK: Int = k /** * Number of topics to infer. I.e., the number of soft cluster centers. * (default = 10) */ + @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") this.k = k @@ -79,7 +87,26 @@ class LDA private ( * * This is the parameter to a Dirichlet distribution. */ - def getDocConcentration: Vector = this.docConcentration + @Since("1.5.0") + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + @Since("1.3.0") + def getDocConcentration: Double = { + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter + } else { + require(docConcentration.toArray.forall(_ == parameter)) + parameter + } + } /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' @@ -105,24 +132,44 @@ class LDA private ( * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. 
*/ + @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Replicates Double to create a symmetric prior */ + /** + * Replicates a [[Double]] docConcentration to create a symmetric prior. + */ + @Since("1.3.0") def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } - /** Alias for [[getDocConcentration]] */ - def getAlpha: Vector = getDocConcentration + /** + * Alias for [[getAsymmetricDocConcentration]] + */ + @Since("1.5.0") + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration + + /** + * Alias for [[getDocConcentration]] + */ + @Since("1.3.0") + def getAlpha: Double = getDocConcentration - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.5.0") def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + */ + @Since("1.3.0") def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) /** @@ -134,6 +181,7 @@ class LDA private ( * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ + @Since("1.3.0") def getTopicConcentration: Double = this.topicConcentration /** @@ -158,35 +206,50 @@ class LDA private ( * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. */ + @Since("1.3.0") def setTopicConcentration(topicConcentration: Double): this.type = { this.topicConcentration = topicConcentration this } - /** Alias for [[getTopicConcentration]] */ + /** + * Alias for [[getTopicConcentration]] + */ + @Since("1.3.0") def getBeta: Double = getTopicConcentration - /** Alias for [[setTopicConcentration()]] */ + /** + * Alias for [[setTopicConcentration()]] + */ + @Since("1.3.0") def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Maximum number of iterations for learning. * (default = 20) */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def getSeed: Long = seed - /** Random seed */ + /** + * Random seed + */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -195,6 +258,7 @@ class LDA private ( /** * Period (in iterations) between checkpoints. 
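The alpha/beta setters and the iteration, seed, and checkpoint controls annotated above all feed into `LDA.run`, which is annotated in the next hunk. A hedged end-to-end sketch on a toy bag-of-words corpus (document vectors and parameter values are made up):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

object LDAExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lda-example").setMaster("local[2]"))

    // Corpus of (document id, term-count vector) pairs over a 5-word vocabulary.
    val corpus = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 0.0, 0.0, 1.0),
      Vectors.dense(0.0, 3.0, 1.0, 0.0, 0.0),
      Vectors.dense(0.0, 0.0, 2.0, 3.0, 1.0),
      Vectors.dense(1.0, 0.0, 0.0, 2.0, 2.0))).zipWithIndex().map(_.swap).cache()

    val ldaModel = new LDA()
      .setK(2)                     // number of topics
      .setDocConcentration(1.1)    // symmetric alpha; a Vector sets an asymmetric prior
      .setTopicConcentration(1.1)  // beta / eta
      .setMaxIterations(20)
      .setSeed(1L)
      .setCheckpointInterval(10)
      .run(corpus)

    ldaModel.describeTopics(maxTermsPerTopic = 3).zipWithIndex.foreach {
      case ((terms, weights), topic) =>
        println(s"topic $topic: " + terms.zip(weights).mkString(", "))
    }
    sc.stop()
  }
}
```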
*/ + @Since("1.3.0") def getCheckpointInterval: Int = checkpointInterval /** @@ -205,6 +269,7 @@ class LDA private ( * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ + @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -216,6 +281,7 @@ class LDA private ( * * LDAOptimizer used to perform the actual calculation */ + @Since("1.4.0") @DeveloperApi def getOptimizer: LDAOptimizer = ldaOptimizer @@ -224,6 +290,7 @@ class LDA private ( * * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) */ + @Since("1.4.0") @DeveloperApi def setOptimizer(optimizer: LDAOptimizer): this.type = { this.ldaOptimizer = optimizer @@ -234,6 +301,7 @@ class LDA private ( * Set the LDAOptimizer used to perform the actual calculation by algorithm name. * Currently "em", "online" are supported. */ + @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = optimizerName.toLowerCase match { @@ -254,6 +322,7 @@ class LDA private ( * Document IDs must be unique and >= 0. * @return Inferred LDA model */ + @Since("1.3.0") def run(documents: RDD[(Long, Vector)]): LDAModel = { val state = ldaOptimizer.initialize(documents, this) var iter = 0 @@ -268,7 +337,10 @@ class LDA private ( state.getLDAModel(iterationTimes) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + */ + @Since("1.3.0") def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 33babda69bbb..667374a2bc41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -25,9 +25,8 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -190,14 +189,19 @@ class LocalLDAModel private[clustering] ( val topics: Matrix, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel with Serializable { + @Since("1.3.0") override def k: Int = topics.numCols + @Since("1.3.0") override def vocabSize: Int = topics.numRows + @Since("1.3.0") override def topicsMatrix: Matrix = topics + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix 
Range(0, k).map { topicIndex => @@ -210,6 +214,7 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -224,10 +229,19 @@ class LocalLDAModel private[clustering] ( * @param documents test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus */ + @Since("1.5.0") def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + /** + * Java-friendly version of [[logLikelihood]] + */ + @Since("1.5.0") + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Calculate an upper bound bound on perplexity. (Lower is better.) * See Equation (16) in original Online LDA paper. @@ -235,6 +249,7 @@ class LocalLDAModel private[clustering] ( * @param documents test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. */ + @Since("1.5.0") def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusTokenCount = documents .map { case (_, termCounts) => termCounts.toArray.sum } @@ -242,6 +257,12 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } + /** Java-friendly version of [[logPerplexity]] */ + @Since("1.5.0") + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Estimate the variational likelihood bound of from `documents`: * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] @@ -316,6 +337,7 @@ class LocalLDAModel private[clustering] ( * @param documents documents to predict topic mixture distributions for * @return An RDD of (document ID, topic mixture distribution for document) */ + @Since("1.3.0") // TODO: declare in LDAModel and override once implemented in DistributedLDAModel def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { // Double transpose because dirichletExpectation normalizes by row and we need to normalize @@ -341,8 +363,17 @@ class LocalLDAModel private[clustering] ( } } -} + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } +} @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -396,7 +427,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { Loader.checkSchema[Data](dataFrame.schema) val topics = dataFrame.collect() val vocabSize = topics(0).getAs[Vector](0).size - val k = topics.size + val k = topics.length val brzTopics = BDM.zeros[Double](vocabSize, k) topics.foreach { case Row(vec: Vector, ind: Int) => @@ -409,6 +440,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } } + @Since("1.5.0") override def load(sc: SparkContext, path: String): LocalLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -455,8 
+487,9 @@ class DistributedLDAModel private[clustering] ( val vocabSize: Int, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ @@ -465,6 +498,7 @@ class DistributedLDAModel private[clustering] ( * The local model stores the inferred topics but not the topic distributions for training * documents. */ + @Since("1.3.0") def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -475,6 +509,7 @@ class DistributedLDAModel private[clustering] ( * * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. */ + @Since("1.3.0") override lazy val topicsMatrix: Matrix = { // Collect row-major topics val termTopicCounts: Array[(Int, TopicCounts)] = @@ -493,6 +528,7 @@ class DistributedLDAModel private[clustering] ( Matrices.fromBreeze(brzTopics) } + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights @@ -532,6 +568,7 @@ class DistributedLDAModel private[clustering] ( * (IDs for the documents, weights of the topic in these documents). * For each topic, documents are sorted in order of decreasing topic weights. */ + @Since("1.5.0") def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { val numTopics = k val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = @@ -558,6 +595,50 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most + * likely topic generating each term? + * + * @return RDD of (doc ID, assignment of top topic index for each term), + * where the assignment is specified via a pair of zippable arrays + * (term indices, topic indices). Note that terms will be omitted if not present in + * the document. + */ + lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { + // For reference, compare the below code with the core part of EMLDAOptimizer.next(). + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration(0) + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit = + (edgeContext) => { + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions). + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) + // For this (doc j, term w), send top topic k to doc vertex. + val topTopic: Int = argmax(scaledTopicDistribution) + val term: Int = index2term(edgeContext.dstId) + edgeContext.sendToSrc((Array(term), Array(topTopic))) + } + val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) = + (terms_topics0, terms_topics1) => { + (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. 
+ val perDocAssignments = + graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex) + perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) => + // TODO: Avoid zip, which is inefficient. + val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip + (docID, sortedTerms.toArray, sortedTopics.toArray) + } + } + + /** Java-friendly version of [[topicAssignments]] */ + lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { + topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -571,6 +652,7 @@ class DistributedLDAModel private[clustering] ( * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the * hyperparameters. */ + @Since("1.3.0") lazy val logLikelihood: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -595,8 +677,9 @@ class DistributedLDAModel private[clustering] ( /** * Log probability of the current parameter estimate: - * log P(topics, topic distributions for docs | alpha, eta) + * log P(topics, topic distributions for docs | alpha, eta) */ + @Since("1.3.0") lazy val logPrior: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -628,13 +711,17 @@ class DistributedLDAModel private[clustering] ( * * @return RDD of (document ID, topic distribution) pairs */ + @Since("1.3.0") def topicDistributions: RDD[(Long, Vector)] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) } } - /** Java-friendly version of [[topicDistributions]] */ + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.4.1") def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } @@ -643,6 +730,7 @@ class DistributedLDAModel private[clustering] ( * For each document, return the top k weighted topics for that document and their weights. * @return RDD of (doc ID, topic indices, topic weights) */ + @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => val topIndices = argtopk(topicCounts, k) @@ -656,11 +744,24 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Java-friendly version of [[topTopicsPerDocument]] + */ + @Since("1.5.0") + def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
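The new `topicAssignments` field above yields, per document, zippable arrays of term indices and their most likely topics. A hedged sketch of how a caller might consume it after EM training; `corpus` is assumed to be an existing `RDD[(Long, Vector)]`, for example built as in the earlier LDA sketch:

```scala
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def inspectAssignments(corpus: RDD[(Long, Vector)]): Unit = {
  // The EM optimizer is what produces a DistributedLDAModel.
  val model = new LDA().setK(2).setOptimizer("em").setMaxIterations(20).run(corpus)
  val distModel = model.asInstanceOf[DistributedLDAModel]

  distModel.topicAssignments.take(3).foreach { case (docId, terms, topics) =>
    println(s"doc $docId: " +
      terms.zip(topics).map { case (t, z) => s"term $t -> topic $z" }.mkString(", "))
  }

  // Top documents per topic and per-document topic mixtures, also annotated above.
  val topDocs = distModel.topDocumentsPerTopic(maxDocumentsPerTopic = 2)
  println(topDocs.map(_._1.mkString("[", ",", "]")).mkString("; "))
  distModel.topicDistributions.take(3).foreach(println)

  // Collapse to a LocalLDAModel when the per-document state is no longer needed.
  val local = distModel.toLocal
  println(local.vocabSize)
}
```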
override protected def formatVersion = "1.0" + /** + * Java-friendly version of [[topicDistributions]] + */ + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, @@ -756,11 +857,12 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, gammaShape, iterationTimes) + docConcentration, topicConcentration, iterationTimes, gammaShape) } } + @Since("1.5.0") override def load(sc: SparkContext, path: String): DistributedLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -774,10 +876,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { - case (className, "1.0") if className == classNameV1_0 => { + case (className, "1.0") if className == classNameV1_0 => DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray, gammaShape) - } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index afba2866c704..5c2aae6403be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, su import breeze.numerics.{trigamma, abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer @@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can * hold optimizer-specific parameters for users to set. */ +@Since("1.4.0") @DeveloperApi sealed trait LDAOptimizer { @@ -73,8 +74,8 @@ sealed trait LDAOptimizer { * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. - * */ +@Since("1.4.0") @DeveloperApi final class EMLDAOptimizer extends LDAOptimizer { @@ -95,10 +96,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. 
*/ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { - val docConcentration = lda.getDocConcentration(0) - require({ - lda.getDocConcentration.toArray.forall(_ == docConcentration) - }, "EMLDAOptimizer currently only supports symmetric document-topic priors") + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -168,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer { edgeContext.sendToDst((false, scaledTopicDistribution)) edgeContext.sendToSrc((false, scaledTopicDistribution)) } - // This is a hack to detect whether we could modify the values in-place. + // The Boolean is a hack to detect whether we could modify the values in-place. // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = (m0, m1) => { @@ -209,11 +208,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal - // conversion + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - 100, iterationTimes) + iterationTimes) } } @@ -228,6 +227,7 @@ final class EMLDAOptimizer extends LDAOptimizer { * Original Online LDA paper: * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. */ +@Since("1.4.0") @DeveloperApi final class OnlineLDAOptimizer extends LDAOptimizer { @@ -277,6 +277,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. */ + @Since("1.4.0") def getTau0: Double = this.tau0 /** @@ -284,6 +285,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * iterations count less. * Default: 1024, following the original Online LDA paper. */ + @Since("1.4.0") def setTau0(tau0: Double): this.type = { require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 @@ -293,6 +295,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Learning rate: exponential decay rate */ + @Since("1.4.0") def getKappa: Double = this.kappa /** @@ -300,6 +303,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * (0.5, 1.0] to guarantee asymptotic convergence. * Default: 0.51, based on the original Online LDA paper. */ + @Since("1.4.0") def setKappa(kappa: Double): this.type = { require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") this.kappa = kappa @@ -309,6 +313,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration */ + @Since("1.4.0") def getMiniBatchFraction: Double = this.miniBatchFraction /** @@ -321,6 +326,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * * Default: 0.05, i.e., 5% of total documents. 
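The tau0/kappa/miniBatchFraction knobs annotated above belong to the online variational optimizer; plugged into `LDA.setOptimizer`, they yield a `LocalLDAModel` whose 1.5 evaluation methods were annotated earlier in this patch. A hedged sketch, again assuming an existing `corpus: RDD[(Long, Vector)]`:

```scala
import org.apache.spark.mllib.clustering.{LDA, LocalLDAModel, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def trainOnline(corpus: RDD[(Long, Vector)]): Unit = {
  val optimizer = new OnlineLDAOptimizer()
    .setTau0(1024.0)            // downweights early iterations
    .setKappa(0.51)             // decay rate, kept in (0.5, 1.0] for convergence
    .setMiniBatchFraction(0.05) // fraction of documents sampled per iteration

  val model = new LDA()
    .setK(10)
    .setOptimizer(optimizer)
    .setMaxIterations(50)
    .run(corpus)
    .asInstanceOf[LocalLDAModel] // the online optimizer produces a local model

  // Evaluation helpers annotated in the LDAModel hunk above.
  println(s"log-likelihood bound: ${model.logLikelihood(corpus)}")
  println(s"log-perplexity per token: ${model.logPerplexity(corpus)}")
}
```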
*/ + @Since("1.4.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") @@ -332,6 +338,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) * will be optimized during training. */ + @Since("1.5.0") def getOptimzeAlpha: Boolean = this.optimizeAlpha /** @@ -339,6 +346,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * * Default: false */ + @Since("1.5.0") def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { this.optimizeAlpha = optimizeAlpha this @@ -378,18 +386,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration.size == 1) { - if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) else { - require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") - Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) } } else { - require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") - lda.getDocConcentration.foreachActive { case (_, x) => + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") } - lda.getDocConcentration + lda.getAsymmetricDocConcentration } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index f7e5ce1665fe..a9ba7b60bad0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -22,7 +22,7 @@ import breeze.numerics._ /** * Utility methods for LDA. 
*/ -object LDAUtils { +private[clustering] object LDAUtils { /** * Log Sum Exp with overflow protection using the identity: * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 407e43a024a2..396b36f2f645 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -21,7 +21,7 @@ import org.json4s.JsonDSL._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl @@ -40,11 +40,13 @@ import org.apache.spark.{Logging, SparkContext, SparkException} * @param k number of clusters * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ +@Since("1.3.0") @Experimental class PowerIterationClusteringModel( val k: Int, val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) } @@ -52,6 +54,7 @@ class PowerIterationClusteringModel( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) @@ -65,6 +68,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + /** + */ + @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ @@ -77,6 +83,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode dataRDD.write.parquet(Loader.dataPath(path)) } + @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) @@ -120,14 +127,17 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ - /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + /** + * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. */ + @Since("1.3.0") def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this @@ -136,6 +146,7 @@ class PowerIterationClustering private[clustering] ( /** * Set maximum number of iterations of the power iteration loop */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -145,6 +156,7 @@ class PowerIterationClustering private[clustering] ( * Set the initialization mode. 
This can be either "random" to use a random vector * as vertex properties, or "degree" to use normalized sum similarities. Default: random. */ + @Since("1.3.0") def setInitializationMode(mode: String): this.type = { this.initMode = mode match { case "random" | "degree" => mode @@ -165,6 +177,7 @@ class PowerIterationClustering private[clustering] ( * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ + @Since("1.5.0") def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { val w = normalize(graph) val w0 = initMode match { @@ -186,6 +199,7 @@ class PowerIterationClustering private[clustering] ( * * @return a [[PowerIterationClusteringModel]] that contains the clustering result */ + @Since("1.3.0") def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) val w0 = initMode match { @@ -198,6 +212,7 @@ class PowerIterationClustering private[clustering] ( /** * A Java-friendly version of [[PowerIterationClustering.run]]. */ + @Since("1.3.0") def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) : PowerIterationClusteringModel = { run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) @@ -221,6 +236,7 @@ class PowerIterationClustering private[clustering] ( } } +@Since("1.3.0") @Experimental object PowerIterationClustering extends Logging { @@ -230,6 +246,7 @@ object PowerIterationClustering extends Logging { * @param id node id * @param cluster assigned cluster id */ + @Since("1.3.0") @Experimental case class Assignment(id: Long, cluster: Int) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d9b34cec6489..41f2668ec6a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -63,14 +63,17 @@ import org.apache.spark.util.random.XORShiftRandom * such that at time t + h the discount applied to the data from t is 0.5. * The definition remains the same whether the time unit is given * as batches or points. - * */ +@Since("1.2.0") @Experimental class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { - /** Perform a k-means update on a batch of data. */ + /** + * Perform a k-means update on a batch of data. 
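Stepping back to the `PowerIterationClustering` hunks above: the similarity-triplet `run` overload takes `(srcId, dstId, similarity)` tuples with non-negative similarities, and the resulting model exposes per-node `Assignment`s. A hedged sketch with a made-up affinity list:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.PowerIterationClustering

object PICExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("pic-example").setMaster("local[2]"))

    // (srcId, dstId, similarity); only one direction per pair needs to be listed.
    val similarities = sc.parallelize(Seq(
      (0L, 1L, 0.9), (1L, 2L, 0.9), (0L, 2L, 0.8),   // tight triangle
      (3L, 4L, 0.9), (4L, 5L, 0.9), (3L, 5L, 0.8),   // second tight triangle
      (2L, 3L, 0.1)))                                 // weak bridge between them

    val model = new PowerIterationClustering()
      .setK(2)
      .setMaxIterations(20)
      .setInitializationMode("degree")
      .run(similarities)

    model.assignments.collect().foreach { a =>
      println(s"node ${a.id} -> cluster ${a.cluster}")
    }
    sc.stop()
  }
}
```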
+ */ + @Since("1.2.0") def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { // find nearest cluster to each point @@ -82,6 +85,7 @@ class StreamingKMeansModel( (p1._1, p1._2 + p2._2) } val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) .collect() @@ -162,29 +166,40 @@ class StreamingKMeansModel( * .trainOn(DStream) * }}} */ +@Since("1.2.0") @Experimental class StreamingKMeans( var k: Int, var decayFactor: Double, var timeUnit: String) extends Logging with Serializable { + @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) - /** Set the number of clusters. */ + /** + * Set the number of clusters. + */ + @Since("1.2.0") def setK(k: Int): this.type = { this.k = k this } - /** Set the decay factor directly (for forgetful algorithms). */ + /** + * Set the decay factor directly (for forgetful algorithms). + */ + @Since("1.2.0") def setDecayFactor(a: Double): this.type = { this.decayFactor = a this } - /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + /** + * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + */ + @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) @@ -195,7 +210,10 @@ class StreamingKMeans( this } - /** Specify initial centers directly. */ + /** + * Specify initial centers directly. + */ + @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { model = new StreamingKMeansModel(centers, weights) this @@ -208,6 +226,7 @@ class StreamingKMeans( * @param weight Weight for each center * @param seed Random seed */ + @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) @@ -216,7 +235,10 @@ class StreamingKMeans( this } - /** Return the latest model. */ + /** + * Return the latest model. + */ + @Since("1.2.0") def latestModel(): StreamingKMeansModel = { model } @@ -229,6 +251,7 @@ class StreamingKMeans( * * @param data DStream containing vector data */ + @Since("1.2.0") def trainOn(data: DStream[Vector]) { assertInitialized() data.foreachRDD { (rdd, time) => @@ -236,7 +259,10 @@ class StreamingKMeans( } } - /** Java-friendly version of `trainOn`. */ + /** + * Java-friendly version of `trainOn`. + */ + @Since("1.4.0") def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) /** @@ -245,12 +271,16 @@ class StreamingKMeans( * @param data DStream containing vector data * @return DStream containing predictions */ + @Since("1.2.0") def predictOn(data: DStream[Vector]): DStream[Int] = { assertInitialized() data.map(model.predict) } - /** Java-friendly version of `predictOn`. */ + /** + * Java-friendly version of `predictOn`. 
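`StreamingKMeans` wires the setters annotated above to a pair of DStreams: `trainOn` updates the model once per batch and `predictOnValues` scores a (key, vector) stream with the latest centers. A hedged sketch modeled on the usage block quoted in the class scaladoc; the directory paths and dimensions are placeholders:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingKMeansExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("streaming-kmeans").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(5))

    // Training vectors arrive as text files of the form "[x1,x2,x3]".
    val trainingData = ssc.textFileStream("/tmp/kmeans/train").map(Vectors.parse)
    // Test points arrive as LabeledPoint text, "(label,[x1,x2,x3])".
    val testData = ssc.textFileStream("/tmp/kmeans/test").map(LabeledPoint.parse)

    val model = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(1.0)          // or setHalfLife(5, "batches") for forgetful updates
      .setRandomCenters(dim = 3, weight = 0.0, seed = 42L)

    model.trainOn(trainingData)
    model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```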
+ */ + @Since("1.4.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) } @@ -262,12 +292,16 @@ class StreamingKMeans( * @tparam K key type * @return DStream containing the input keys and the predictions as values */ + @Since("1.2.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { assertInitialized() data.mapValues(model.predict) } - /** Java-friendly version of `predictOnValues`. */ + /** + * Java-friendly version of `predictOnValues`. + */ + @Since("1.4.0") def predictOnValues[K]( data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { implicit val tag = fakeClassTag[K] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index c1d1a224817e..76ae847921f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ @@ -42,6 +42,7 @@ import org.apache.spark.sql.DataFrame * be smaller as a result, meaning there may be an extra sample at * partition boundaries. */ +@Since("1.3.0") @Experimental class BinaryClassificationMetrics( val scoreAndLabels: RDD[(Double, Double)], @@ -52,6 +53,7 @@ class BinaryClassificationMetrics( /** * Defaults `numBins` to 0. */ + @Since("1.0.0") def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) /** @@ -61,12 +63,18 @@ class BinaryClassificationMetrics( private[mllib] def this(scoreAndLabels: DataFrame) = this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * Unpersist intermediate RDDs used in the computation. + */ + @Since("1.0.0") def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + */ + @Since("1.0.0") def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -75,6 +83,7 @@ class BinaryClassificationMetrics( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.0.0") def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) val sc = confusions.context @@ -86,6 +95,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. */ + @Since("1.0.0") def areaUnderROC(): Double = AreaUnderCurve.of(roc()) /** @@ -93,6 +103,7 @@ class BinaryClassificationMetrics( * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall */ + @Since("1.0.0") def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) val sc = confusions.context @@ -103,6 +114,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. 
*/ + @Since("1.0.0") def areaUnderPR(): Double = AreaUnderCurve.of(pr()) /** @@ -111,15 +123,25 @@ class BinaryClassificationMetrics( * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score */ + @Since("1.0.0") def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + */ + @Since("1.0.0") def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + */ + @Since("1.0.0") def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + */ + @Since("1.0.0") def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 4628dc569091..02e89d921033 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation import scala.collection.Map import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame * * @param predictionAndLabels an RDD of (prediction, label) pairs. */ +@Since("1.1.0") @Experimental class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { @@ -65,6 +66,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * they are ordered by class label ascending, * as in "labels" */ + @Since("1.1.0") def confusionMatrix: Matrix = { val n = labels.size val values = Array.ofDim[Double](n * n) @@ -84,12 +86,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns true positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. */ + @Since("1.1.0") def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) fp.toDouble / (labelCount - labelCountByClass(label)) @@ -99,6 +103,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns precision for a given label (category) * @param label the label. */ + @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) val fp = fpByClass.getOrElse(label, 0) @@ -109,6 +114,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns recall for a given label (category) * @param label the label. */ + @Since("1.1.0") def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) /** @@ -116,6 +122,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * @param label the label. * @param beta the beta parameter. 
*/ + @Since("1.1.0") def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) val r = recall(label) @@ -127,6 +134,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f1-measure for a given label (category) * @param label the label. */ + @Since("1.1.0") def fMeasure(label: Double): Double = fMeasure(label, 1.0) /** @@ -180,6 +188,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged f-measure * @param beta the beta parameter. */ + @Since("1.1.0") def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount }.sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index bf6eb1d5bd2a..a0a8d9c56847 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.evaluation +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.sql.DataFrame @@ -26,6 +27,7 @@ import org.apache.spark.sql.DataFrame * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. */ +@Since("1.2.0") class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { /** @@ -104,6 +106,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns precision for a given label (category) * @param label the label. */ + @Since("1.2.0") def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) @@ -114,6 +117,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns recall for a given label (category) * @param label the label. */ + @Since("1.2.0") def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) @@ -124,6 +128,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns f1-measure for a given label (category) * @param label the label. */ + @Since("1.2.0") def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 5b5a2a1450f7..a7f43f0b110f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. 
*/ +@Since("1.2.0") @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) extends Logging with Serializable { @@ -56,6 +57,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions */ + @Since("1.2.0") def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -125,6 +127,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions */ + @Since("1.2.0") def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -163,6 +166,7 @@ object RankingMetrics { * Creates a [[RankingMetrics]] instance (for Java users). * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs */ + @Since("1.4.0") def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 408847afa800..36a6c357c389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors @@ -30,6 +30,7 @@ import org.apache.spark.sql.DataFrame * * @param predictionAndObservations an RDD of (prediction, observation) pairs. */ +@Since("1.2.0") @Experimental class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { @@ -67,6 +68,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ + @Since("1.2.0") def explainedVariance: Double = { SSreg / summary.count } @@ -75,6 +77,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. */ + @Since("1.2.0") def meanAbsoluteError: Double = { summary.normL1(1) / summary.count } @@ -83,6 +86,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. */ + @Since("1.2.0") def meanSquaredError: Double = { SSerr / summary.count } @@ -91,6 +95,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the root mean squared error, which is defined as the square root of * the mean squared error. 
*/ + @Since("1.2.0") def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) } @@ -99,6 +104,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ + @Since("1.2.0") def r2: Double = { 1 - SSerr / SStot } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 5f8c1dea237b..fdd974d7a391 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics @@ -31,8 +31,10 @@ import org.apache.spark.rdd.RDD * * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ +@Since("1.3.0") @Experimental -class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { +class ChiSqSelectorModel ( + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -52,6 +54,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.3.0") override def transform(vector: Vector): Vector = { compress(vector, selectedFeatures) } @@ -107,8 +110,10 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) */ +@Since("1.3.0") @Experimental -class ChiSqSelector (val numTopFeatures: Int) extends Serializable { +class ChiSqSelector ( + @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. @@ -117,6 +122,7 @@ class ChiSqSelector (val numTopFeatures: Int) extends Serializable { * Real-valued features will be treated as categorical for each distinct value. * Apply feature discretizer before using this function. */ + @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val indices = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index d67fe6c3ee4f..33e2d17bb472 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg._ /** @@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg._ * multiplier. * @param scalingVec The values used to scale the reference vector's individual components. 
*/ +@Since("1.4.0") @Experimental class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { @@ -36,6 +37,7 @@ class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.4.0") override def transform(vector: Vector): Vector = { require(vector.size == scalingVec.size, s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c53475818395..e47d524b6162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -22,7 +22,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -34,19 +34,25 @@ import org.apache.spark.util.Utils * * @param numFeatures number of features (default: 2^20^) */ +@Since("1.1.0") @Experimental class HashingTF(val numFeatures: Int) extends Serializable { + /** + */ + @Since("1.1.0") def this() = this(1 << 20) /** * Returns the index of the input term. */ + @Since("1.1.0") def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) /** * Transforms the input document into a sparse term frequency vector. */ + @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -59,6 +65,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document into a sparse term frequency vector (Java version). */ + @Since("1.1.0") def transform(document: JavaIterable[_]): Vector = { transform(document.asScala) } @@ -66,6 +73,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors. */ + @Since("1.1.0") def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { dataset.map(this.transform) } @@ -73,6 +81,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors (Java version). 
*/ + @Since("1.1.0") def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { dataset.rdd.map(this.transform).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3fab7ea79bef..d5353ddd972e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * @param minDocFreq minimum of documents in which a term * should appear for filtering */ +@Since("1.1.0") @Experimental class IDF(val minDocFreq: Int) { @@ -48,6 +49,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: RDD[Vector]): IDFModel = { val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( minDocFreq = minDocFreq))( @@ -61,6 +63,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } @@ -171,6 +174,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: RDD[Vector]): RDD[Vector] = { val bcIdf = dataset.context.broadcast(idf) dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v))) @@ -182,6 +186,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param v a term frequency vector * @return a TF-IDF vector */ + @Since("1.3.0") def transform(v: Vector): Vector = IDFModel.transform(idf, v) /** @@ -189,6 +194,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset a JavaRDD of term frequency vectors * @return a JavaRDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { transform(dataset.rdd).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 32848e039eb8..0e070257d9fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** @@ -31,9 +31,11 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors * * @param p Normalization in L^p^ space, p = 2 by default. */ +@Since("1.1.0") @Experimental class Normalizer(p: Double) extends VectorTransformer { + @Since("1.1.0") def this() = this(2) require(p >= 1.0) @@ -44,6 +46,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @param vector vector to be normalized. * @return normalized vector. 
If the norm of the input is zero, it will return the input vector. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { val norm = Vectors.norm(vector, p) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 2a66263d8b7d..a48b7bba665d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.feature +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -27,6 +28,7 @@ import org.apache.spark.rdd.RDD * * @param k number of principal components */ +@Since("1.4.0") class PCA(val k: Int) { require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") @@ -35,6 +37,7 @@ class PCA(val k: Int) { * * @param sources source vectors */ + @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { require(k <= sources.first().size, s"source vector size is ${sources.first().size} must be greater than k=$k") @@ -58,7 +61,10 @@ class PCA(val k: Int) { new PCAModel(k, pc) } - /** Java-friendly version of [[fit()]] */ + /** + * Java-friendly version of [[fit()]] + */ + @Since("1.4.0") def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) } @@ -76,6 +82,7 @@ class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTr * Vector must be the same length as the source vectors given to [[PCA.fit()]]. * @return transformed vector. Vector will be of length k. */ + @Since("1.4.0") override def transform(vector: Vector): Vector = { vector match { case dv: DenseVector => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index c73b8f258060..b95d5a899001 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -32,9 +32,11 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ +@Since("1.1.0") @Experimental class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { + @Since("1.1.0") def this() = this(false, true) if (!(withMean || withStd)) { @@ -47,6 +49,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param data The data used to compute the mean and variance to build the transformation model. 
* @return a StandardScalarModel */ + @Since("1.1.0") def fit(data: RDD[Vector]): StandardScalerModel = { // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( @@ -69,6 +72,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param withStd whether to scale the data to have unit standard deviation * @param withMean whether to center the data before scaling */ +@Since("1.1.0") @Experimental class StandardScalerModel ( val std: Vector, @@ -76,6 +80,9 @@ class StandardScalerModel ( var withStd: Boolean, var withMean: Boolean) extends VectorTransformer { + /** + * Creates a model from the given column standard deviations and means; `withStd` and + * `withMean` are enabled when the corresponding argument is non-null. + */ + @Since("1.3.0") def this(std: Vector, mean: Vector) { this(std, mean, withStd = std != null, withMean = mean != null) require(this.withStd || this.withMean, @@ -86,8 +93,10 @@ class StandardScalerModel ( } } + @Since("1.3.0") def this(std: Vector) = this(std, null) + @Since("1.3.0") @DeveloperApi def setWithMean(withMean: Boolean): this.type = { require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") @@ -95,6 +104,7 @@ class StandardScalerModel ( this } + @Since("1.3.0") @DeveloperApi def setWithStd(withStd: Boolean): this.type = { require(!(withStd && this.std == null), @@ -115,6 +125,7 @@ class StandardScalerModel ( * @return Standardized vector. If the std of a column is zero, it will return default `0.0` * for the column with zero std. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 7358c1c84f79..5778fd1d0925 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for transformation of a vector */ +@Since("1.1.0") @DeveloperApi trait VectorTransformer extends Serializable { @@ -35,6 +36,7 @@ trait VectorTransformer extends Serializable { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.1.0") def transform(vector: Vector): Vector /** @@ -43,6 +45,7 @@ trait VectorTransformer extends Serializable { * @param data RDD[Vector] to be transformed. * @return transformed RDD[Vector]. */ + @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. // So it should be no longer necessary to explicitly broadcast `this` object. @@ -55,6 +58,7 @@ trait VectorTransformer extends Serializable { * @param data JavaRDD[Vector] to be transformed. * @return transformed JavaRDD[Vector].
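For reference, a sketch (illustrative only, not part of the patch) of the `StandardScaler` fit/transform flow annotated above; the feature RDD is a placeholder:

```scala
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object StandardScalerSketch {
  def standardize(data: RDD[Vector]): RDD[Vector] = {
    // withMean = true centers the data and produces dense output, so it is not suited to sparse input.
    val scaler = new StandardScaler(withMean = true, withStd = true)
    val model = scaler.fit(data) // computes per-column mean and standard deviation
    model.transform(data)        // returns centered, unit-variance vectors
  }
}
```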
*/ + @Since("1.1.0") def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { transform(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index cbbd2b0c8d06..e6f45ae4b01d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -32,7 +32,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -70,6 +70,7 @@ private case class VocabWord( * and * Distributed Representations of Words and Phrases and their Compositionality. */ +@Since("1.1.0") @Experimental class Word2Vec extends Serializable with Logging { @@ -83,6 +84,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets vector size (default: 100). */ + @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { this.vectorSize = vectorSize this @@ -91,6 +93,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets initial learning rate (default: 0.025). */ + @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { this.learningRate = learningRate this @@ -99,6 +102,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets number of partitions (default: 1). Use a small number for accuracy. */ + @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") this.numPartitions = numPartitions @@ -109,6 +113,7 @@ class Word2Vec extends Serializable with Logging { * Sets number of iterations (default: 1), which should be smaller than or equal to number of * partitions. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.numIterations = numIterations this @@ -117,6 +122,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets random seed (default: a random long integer). */ + @Since("1.1.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -126,6 +132,7 @@ class Word2Vec extends Serializable with Logging { * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). 
*/ + @Since("1.3.0") def setMinCount(minCount: Int): this.type = { this.minCount = minCount this @@ -263,6 +270,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset an RDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { val words = dataset.flatMap(x => x) @@ -412,6 +420,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset a JavaRDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { fit(dataset.rdd.map(_.asScala)) } @@ -454,6 +463,7 @@ class Word2VecModel private[mllib] ( wordVecNorms } + @Since("1.5.0") def this(model: Map[String, Array[Float]]) = { this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } @@ -469,6 +479,7 @@ class Word2VecModel private[mllib] ( override protected def formatVersion = "1.0" + @Since("1.4.0") def save(sc: SparkContext, path: String): Unit = { Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } @@ -478,6 +489,7 @@ class Word2VecModel private[mllib] ( * @param word a word * @return vector representation of word */ + @Since("1.1.0") def transform(word: String): Vector = { wordIndex.get(word) match { case Some(ind) => @@ -494,6 +506,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) findSynonyms(vector, num) @@ -505,6 +518,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k @@ -534,6 +548,7 @@ class Word2VecModel private[mllib] ( /** * Returns a map of words to their vector representations. */ + @Since("1.2.0") def getVectors: Map[String, Array[Float]] = { wordIndex.map { case (word, ind) => (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) @@ -541,6 +556,7 @@ class Word2VecModel private[mllib] ( } } +@Since("1.4.0") @Experimental object Word2VecModel extends Loader[Word2VecModel] { @@ -600,6 +616,7 @@ object Word2VecModel extends Loader[Word2VecModel] { } } + @Since("1.4.0") override def load(sc: SparkContext, path: String): Word2VecModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 72d0ea0c12e1..ba3b447a8339 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.mllib.fpm +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.AssociationRules.Rule @@ -32,24 +33,22 @@ import org.apache.spark.rdd.RDD * Generates association rules from a [[RDD[FreqItemset[Item]]]. 
This method only generates * association rules which have a single item as the consequent. * - * @since 1.5.0 */ +@Since("1.5.0") @Experimental class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minConfidence = 0.8}. - * - * @since 1.5.0 */ + @Since("1.5.0") def this() = this(0.8) /** * Sets the minimal confidence (default: `0.8`). - * - * @since 1.5.0 */ + @Since("1.5.0") def setMinConfidence(minConfidence: Double): this.type = { require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence @@ -61,8 +60,8 @@ class AssociationRules private[fpm] ( * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] * @return a [[Set[Rule[Item]]] containing the assocation rules. * - * @since 1.5.0 */ + @Since("1.5.0") def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => @@ -95,12 +94,14 @@ object AssociationRules { * :: Experimental :: * * An association rule between sets of items. - * @param antecedent hypotheses of the rule - * @param consequent conclusion of the rule + * @param antecedent hypotheses of the rule. Java users should call [[Rule#javaAntecedent]] + * instead. + * @param consequent conclusion of the rule. Java users should call [[Rule#javaConsequent]] + * instead. * @tparam Item item type * - * @since 1.5.0 */ + @Since("1.5.0") @Experimental class Rule[Item] private[fpm] ( val antecedent: Array[Item], @@ -108,6 +109,11 @@ object AssociationRules { freqUnion: Double, freqAntecedent: Double) extends Serializable { + /** + * Returns the confidence of the rule. + * + */ + @Since("1.5.0") def confidence: Double = freqUnion.toDouble / freqAntecedent require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { @@ -115,5 +121,23 @@ object AssociationRules { s"A valid association rule must have disjoint antecedent and " + s"consequent but ${sharedItems} is present in both." }) + + /** + * Returns antecedent in a Java List. + * + */ + @Since("1.5.0") + def javaAntecedent: java.util.List[Item] = { + antecedent.toList.asJava + } + + /** + * Returns consequent in a Java List. + * + */ + @Since("1.5.0") + def javaConsequent: java.util.List[Item] = { + consequent.toList.asJava + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index e2370a52f493..e37f80627168 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ @@ -39,15 +39,15 @@ import org.apache.spark.storage.StorageLevel * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
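To show how the FP-growth and association-rule APIs annotated here fit together, a brief sketch (illustrative only, not part of the patch); the transaction RDD is a placeholder:

```scala
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

object FPGrowthSketch {
  // transactions: each element is one basket of distinct items.
  def mine(transactions: RDD[Array[String]]): Unit = {
    val model = new FPGrowth()
      .setMinSupport(0.3)
      .setNumPartitions(10)
      .run(transactions)
    model.freqItemsets.collect().foreach { itemset =>
      println(s"${itemset.items.mkString("[", ",", "]")}: ${itemset.freq}")
    }
    // Rules with confidence >= 0.8; Java callers can use the new javaAntecedent/javaConsequent views.
    model.generateAssociationRules(0.8).collect().foreach { rule =>
      println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")} @ ${rule.confidence}")
    }
  }
}
```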
* @param confidence minimal confidence of the rules produced - * @since 1.5.0 */ + @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) @@ -71,8 +71,8 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning * (Wikipedia)]] * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class FPGrowth private ( private var minSupport: Double, @@ -82,15 +82,15 @@ class FPGrowth private ( * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same * as the input data}. * - * @since 1.3.0 */ + @Since("1.3.0") def this() = this(0.3, -1) /** * Sets the minimal support level (default: `0.3`). * - * @since 1.3.0 */ + @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = { this.minSupport = minSupport this @@ -99,8 +99,8 @@ class FPGrowth private ( /** * Sets the number of partitions used by parallel FP-growth (default: same as input data). * - * @since 1.3.0 */ + @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { this.numPartitions = numPartitions this @@ -111,8 +111,8 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] * - * @since 1.3.0 */ + @Since("1.3.0") def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") @@ -213,8 +213,8 @@ class FPGrowth private ( /** * :: Experimental :: * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental object FPGrowth { @@ -224,15 +224,15 @@ object FPGrowth { * @param freq frequency * @tparam Item item type * - * @since 1.3.0 */ + @Since("1.3.0") class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { /** * Returns items in a Java List. * - * @since 1.3.0 */ + @Since("1.3.0") def javaItems: java.util.List[Item] = { items.toList.asJava } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ad6715b52f33..dc4ae1d0b69e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -282,25 +282,30 @@ object PrefixSpan extends Logging { largePrefixes = newLargePrefixes } - // Switch to local processing. - val bcSmallPrefixes = sc.broadcast(smallPrefixes) - val distributedFreqPattern = postfixes.flatMap { postfix => - bcSmallPrefixes.value.values.map { prefix => - (prefix.id, postfix.project(prefix).compressed) - }.filter(_._2.nonEmpty) - }.groupByKey().flatMap { case (id, projPostfixes) => - val prefix = bcSmallPrefixes.value(id) - val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) - // TODO: We collect projected postfixes into memory. We should also compare the performance - // TODO: of keeping them on shuffle files. - localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => - (prefix.items ++ pattern, count) + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. 
+ val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } } + // Union local frequent patterns and distributed ones. + freqPatterns = freqPatterns ++ distributedFreqPattern } - // Union local frequent patterns and distributed ones. - val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern) - .persist(StorageLevel.MEMORY_AND_DISK) freqPatterns } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9029093e0fa0..bbbcc8436b7c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -469,7 +469,7 @@ private[spark] object BLAS extends Serializable with Logging { require(A.numCols == x.size, s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, - s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 1c858348bf20..28b5b4637bf1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -228,6 +228,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. 
*/ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) class DenseMatrix( val numRows: Int, @@ -253,12 +254,12 @@ class DenseMatrix( * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.3.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case m: Matrix => toBreeze == m.toBreeze case _ => false } @@ -277,6 +278,7 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) + @Since("1.3.0") override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { @@ -287,6 +289,7 @@ class DenseMatrix( values(index(i, j)) = v } + @Since("1.4.0") override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), @@ -302,6 +305,7 @@ class DenseMatrix( this } + @Since("1.3.0") override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { @@ -332,14 +336,17 @@ class DenseMatrix( } } + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) + @Since("1.5.0") override def numActives: Int = values.length /** * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. */ + @Since("1.3.0") def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -367,6 +374,7 @@ class DenseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. 
*/ +@Since("1.3.0") object DenseMatrix { /** @@ -375,6 +383,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ + @Since("1.3.0") def zeros(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -387,6 +396,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ + @Since("1.3.0") def ones(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -398,6 +408,7 @@ object DenseMatrix { * @param n number of rows and columns of the matrix * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def eye(n: Int): DenseMatrix = { val identity = DenseMatrix.zeros(n, n) var i = 0 @@ -415,6 +426,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -428,6 +440,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -440,6 +453,7 @@ object DenseMatrix { * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.3.0") def diag(vector: Vector): DenseMatrix = { val n = vector.size val matrix = DenseMatrix.zeros(n, n) @@ -475,6 +489,7 @@ object DenseMatrix { * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ +@Since("1.2.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) class SparseMatrix( val numRows: Int, @@ -512,6 +527,7 @@ class SparseMatrix( * order for each column * @param values non-zero matrix entries in column major */ + @Since("1.3.0") def this( numRows: Int, numCols: Int, @@ -519,6 +535,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) @@ -528,6 +549,9 @@ class SparseMatrix( } } + /** + */ + @Since("1.3.0") override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) @@ -551,6 +575,7 @@ class SparseMatrix( } } + @Since("1.4.0") override def copy: SparseMatrix = { new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } @@ -568,6 +593,7 @@ class SparseMatrix( this } + @Since("1.3.0") override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) @@ -602,12 +628,15 @@ class SparseMatrix( * Generate a `DenseMatrix` from the given `SparseMatrix`. 
The new matrix will have isTransposed * set to false. */ + @Since("1.3.0") def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) + @Since("1.5.0") override def numActives: Int = values.length } @@ -615,6 +644,7 @@ class SparseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. */ +@Since("1.3.0") object SparseMatrix { /** @@ -626,6 +656,7 @@ object SparseMatrix { * @param entries Array of (i, j, value) tuples * @return The corresponding `SparseMatrix` */ + @Since("1.3.0") def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) val numEntries = sortedEntries.size @@ -674,6 +705,7 @@ object SparseMatrix { * @param n number of rows and columns of the matrix * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): SparseMatrix = { new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) } @@ -743,6 +775,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextDouble()) @@ -756,6 +789,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextGaussian()) @@ -767,6 +801,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ + @Since("1.3.0") def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { @@ -783,6 +818,7 @@ object SparseMatrix { /** * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]]. 
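For context, a sketch (not part of the patch) exercising a few of the matrix factory methods annotated in this file:

```scala
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors}

object MatricesSketch {
  def main(args: Array[String]): Unit = {
    // 3 x 2 dense matrix, values in column-major order
    val dm = Matrices.dense(3, 2, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    // 3 x 3 sparse matrix assembled from (row, column, value) entries
    val sm = SparseMatrix.fromCOO(3, 3, Seq((0, 0, 1.0), (1, 2, 2.0), (2, 1, 3.0)))
    // identity and diagonal helpers
    val eye = DenseMatrix.eye(3)
    val diag = Matrices.diag(Vectors.dense(1.0, 2.0, 3.0))
    println(s"dm is ${dm.numRows} x ${dm.numCols}, sm has ${sm.numNonzeros} non-zeros, " +
      s"eye(0,0) = ${eye(0, 0)}, diag(2,2) = ${diag(2, 2)}")
  }
}
```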
*/ +@Since("1.0.0") object Matrices { /** @@ -792,6 +828,7 @@ object Matrices { * @param numCols number of columns * @param values matrix entries in column major */ + @Since("1.0.0") def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { new DenseMatrix(numRows, numCols, values) } @@ -805,6 +842,7 @@ object Matrices { * @param rowIndices the row index of the entry * @param values non-zero matrix entries in column major */ + @Since("1.2.0") def sparse( numRows: Int, numCols: Int, @@ -838,6 +876,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ + @Since("1.2.0") def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** @@ -846,6 +885,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of ones */ + @Since("1.2.0") def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** @@ -853,6 +893,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.2.0") def eye(n: Int): Matrix = DenseMatrix.eye(n) /** @@ -860,6 +901,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("1.3.0") def speye(n: Int): Matrix = SparseMatrix.speye(n) /** @@ -869,6 +911,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.2.0") def rand(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.rand(numRows, numCols, rng) @@ -880,6 +923,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprand(numRows, numCols, density, rng) @@ -890,6 +934,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.2.0") def randn(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.randn(numRows, numCols, rng) @@ -901,6 +946,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprandn(numRows, numCols, density, rng) @@ -910,6 +956,7 @@ object Matrices { * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("1.2.0") def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) /** @@ -919,6 +966,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were horizontally concatenated */ + @Since("1.3.0") def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) @@ -977,6 +1025,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were vertically concatenated */ + @Since("1.3.0") def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b416d50a5631..a37aca99d5e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Represents singular value decomposition (SVD) factors. */ +@Since("1.0.0") @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) @@ -31,5 +32,5 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * Represents QR factors. */ @Experimental -case class QRDecomposition[UType, VType](Q: UType, R: VType) +case class QRDecomposition[QType, RType](Q: QType, R: RType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 86c461fa9163..3d577edbe23e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -159,15 +159,13 @@ sealed trait Vector extends Serializable { } /** - * :: DeveloperApi :: + * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@DeveloperApi -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense @@ -243,11 +241,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { * We don't use the name `Vector` because Scala imports * [[scala.collection.immutable.Vector]] by default. */ +@Since("1.0.0") object Vectors { /** * Creates a dense vector from its values. */ + @Since("1.0.0") @varargs def dense(firstValue: Double, otherValues: Double*): Vector = new DenseVector((firstValue +: otherValues).toArray) @@ -256,6 +256,7 @@ object Vectors { /** * Creates a dense vector from a double array. */ + @Since("1.0.0") def dense(values: Array[Double]): Vector = new DenseVector(values) /** @@ -265,6 +266,7 @@ object Vectors { * @param indices index array, must be strictly increasing. * @param values value array, must have the same length as indices. */ + @Since("1.0.0") def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector = new SparseVector(size, indices, values) @@ -274,6 +276,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. 
*/ + @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { require(size > 0, "The size of the requested sparse vector must be greater than 0.") @@ -295,6 +298,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("1.0.0") def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = { sparse(size, elements.asScala.map { case (i, x) => (i.intValue(), x.doubleValue()) @@ -307,6 +311,7 @@ object Vectors { * @param size vector size * @return a zero vector */ + @Since("1.1.0") def zeros(size: Int): Vector = { new DenseVector(new Array[Double](size)) } @@ -314,6 +319,7 @@ object Vectors { /** * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. */ + @Since("1.1.0") def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) } @@ -357,6 +363,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ + @Since("1.3.0") def norm(vector: Vector, p: Double): Double = { require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + s"You specified p=$p.") @@ -409,6 +416,7 @@ object Vectors { * @param v2 second Vector. * @return squared distance between two Vectors. */ + @Since("1.3.0") def sqdist(v1: Vector, v2: Vector): Double = { require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + s"=${v2.size}.") @@ -522,19 +530,24 @@ object Vectors { /** * A dense vector represented by a value array. */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { + @Since("1.0.0") override def size: Int = values.length override def toString: String = values.mkString("[", ",", "]") + @Since("1.0.0") override def toArray: Array[Double] = values private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) + @Since("1.0.0") override def apply(i: Int): Double = values(i) + @Since("1.1.0") override def copy: DenseVector = { new DenseVector(values.clone()) } @@ -566,8 +579,10 @@ class DenseVector(val values: Array[Double]) extends Vector { result } + @Since("1.4.0") override def numActives: Int = size + @Since("1.4.0") override def numNonzeros: Int = { // same as values.count(_ != 0.0) but faster var nnz = 0 @@ -579,6 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector { nnz } + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros val ii = new Array[Int](nnz) @@ -594,6 +610,7 @@ class DenseVector(val values: Array[Double]) extends Vector { new SparseVector(size, ii, vv) } + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -613,6 +630,7 @@ class DenseVector(val values: Array[Double]) extends Vector { } } +@Since("1.3.0") object DenseVector { /** Extracts the value array from a dense vector. */ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) @@ -625,6 +643,7 @@ object DenseVector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. 
*/ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, @@ -640,6 +659,7 @@ class SparseVector( override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" + @Since("1.0.0") override def toArray: Array[Double] = { val data = new Array[Double](size) var i = 0 @@ -651,6 +671,7 @@ class SparseVector( data } + @Since("1.1.0") override def copy: SparseVector = { new SparseVector(size, indices.clone(), values.clone()) } @@ -691,8 +712,10 @@ class SparseVector( result } + @Since("1.4.0") override def numActives: Int = values.length + @Since("1.4.0") override def numNonzeros: Int = { var nnz = 0 values.foreach { v => @@ -703,6 +726,7 @@ class SparseVector( nnz } + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros if (nnz == numActives) { @@ -722,6 +746,7 @@ class SparseVector( } } + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -792,6 +817,7 @@ class SparseVector( } } +@Since("1.3.0") object SparseVector { def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 3323ae7b1fba..94376c24a7ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -129,6 +129,7 @@ private[mllib] object GridPartitioner { * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to * zero, the number of columns will be calculated when `numCols` is invoked. */ +@Since("1.3.0") @Experimental class BlockMatrix( val blocks: RDD[((Int, Int), Matrix)], @@ -150,6 +151,7 @@ class BlockMatrix( * @param colsPerBlock Number of columns that make up each block. The blocks forming the final * columns are not required to have the given number of columns */ + @Since("1.3.0") def this( blocks: RDD[((Int, Int), Matrix)], rowsPerBlock: Int, @@ -157,11 +159,13 @@ class BlockMatrix( this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) } + @Since("1.3.0") override def numRows(): Long = { if (nRows <= 0L) estimateDim() nRows } + @Since("1.3.0") override def numCols(): Long = { if (nCols <= 0L) estimateDim() nCols @@ -193,6 +197,7 @@ class BlockMatrix( * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if * any error is found. */ + @Since("1.3.0") def validate(): Unit = { logDebug("Validating BlockMatrix...") // check if the matrix is larger than the claimed dimensions @@ -229,18 +234,21 @@ class BlockMatrix( } /** Caches the underlying RDD. */ + @Since("1.3.0") def cache(): this.type = { blocks.cache() this } /** Persists the underlying RDD with the specified storage level. */ + @Since("1.3.0") def persist(storageLevel: StorageLevel): this.type = { blocks.persist(storageLevel) this } /** Converts to CoordinateMatrix. 
*/ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => val rowStart = blockRowIndex.toLong * rowsPerBlock @@ -255,6 +263,7 @@ class BlockMatrix( } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.3.0") def toIndexedRowMatrix(): IndexedRowMatrix = { require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + s"numCols: ${numCols()}") @@ -263,6 +272,7 @@ class BlockMatrix( } /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + @Since("1.3.0") def toLocalMatrix(): Matrix = { require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + s"Int.MaxValue. Currently numRows: ${numRows()}") @@ -287,8 +297,11 @@ class BlockMatrix( new DenseMatrix(m, n, values) } - /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the - * same underlying data. Is a lazy operation. */ + /** + * Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. + */ + @Since("1.3.0") def transpose: BlockMatrix = { val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => ((blockColIndex, blockRowIndex), mat.transpose) @@ -302,12 +315,14 @@ class BlockMatrix( new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) } - /** Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. - */ + /** + * Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + @Since("1.3.0") def add(other: BlockMatrix): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") @@ -335,12 +350,14 @@ class BlockMatrix( } } - /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` - * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains - * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output - * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause - * some performance issues until support for multiplying two sparse matrices is added. - */ + /** + * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. 
+ */ + @Since("1.3.0") def multiply(other: BlockMatrix): BlockMatrix = { require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 078d1fac4444..4bb27ec84090 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} @@ -30,6 +30,7 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} * @param j column index * @param value value of the entry */ +@Since("1.0.0") @Experimental case class MatrixEntry(i: Long, j: Long, value: Double) @@ -43,6 +44,7 @@ case class MatrixEntry(i: Long, j: Long, value: Double) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the max column index plus one. */ +@Since("1.0.0") @Experimental class CoordinateMatrix( val entries: RDD[MatrixEntry], @@ -50,9 +52,11 @@ class CoordinateMatrix( private var nCols: Long) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0L) { computeSize() @@ -61,6 +65,7 @@ class CoordinateMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { computeSize() @@ -69,11 +74,13 @@ class CoordinateMatrix( } /** Transposes this CoordinateMatrix. */ + @Since("1.3.0") def transpose(): CoordinateMatrix = { new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) } /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.0.0") def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() if (nl > Int.MaxValue) { @@ -93,11 +100,13 @@ class CoordinateMatrix( * Converts to RowMatrix, dropping row indices after grouping by row index. * The number of columns must be within the integer range. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { toIndexedRowMatrix().toRowMatrix() } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -110,6 +119,7 @@ class CoordinateMatrix( * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { require(rowsPerBlock > 0, s"rowsPerBlock needs to be greater than 0. 
rowsPerBlock: $rowsPerBlock") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index a0e26ce3bc46..e51327ebb7b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -19,9 +19,12 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.annotation.Since + /** * Represents a distributively stored matrix backed by one or more RDDs. */ +@Since("1.0.0") trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 1c33b43ea7a8..6d2c05a47d04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition @@ -28,6 +28,7 @@ import org.apache.spark.mllib.linalg.SingularValueDecomposition * :: Experimental :: * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */ +@Since("1.0.0") @Experimental case class IndexedRow(index: Long, vector: Vector) @@ -42,6 +43,7 @@ case class IndexedRow(index: Long, vector: Vector) * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ +@Since("1.0.0") @Experimental class IndexedRowMatrix( val rows: RDD[IndexedRow], @@ -49,8 +51,10 @@ class IndexedRowMatrix( private var nCols: Int) extends DistributedMatrix { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0) + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { // Calling `first` will throw an exception if `rows` is empty. @@ -59,6 +63,7 @@ class IndexedRowMatrix( nCols } + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { // Reduce will throw an exception if `rows` is empty. @@ -71,11 +76,13 @@ class IndexedRowMatrix( * Drops row indices and converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { new RowMatrix(rows.map(_.vector), 0L, nCols) } /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -88,6 +95,7 @@ class IndexedRowMatrix( * a smaller value. Must be an integer value greater than 0. 
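For reference, a small sketch of building a `CoordinateMatrix` from `MatrixEntry` records and converting it to the other distributed matrix types covered here; the entries are illustrative:
```
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

def coordinateExample(sc: SparkContext): Unit = {
  val entries = sc.parallelize(Seq(
    MatrixEntry(0L, 0L, 1.0),
    MatrixEntry(1L, 2L, 2.0),
    MatrixEntry(3L, 1L, 3.0)))

  val coo = new CoordinateMatrix(entries)      // dimensions inferred from the max indices
  val indexed = coo.toIndexedRowMatrix()       // groups entries by row index
  val blocked = coo.toBlockMatrix(1024, 1024)  // blocks stored as SparseMatrix
  println(s"${coo.numRows()} x ${coo.numCols()}, transpose rows: ${coo.transpose().numRows()}")
}
```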
* @return a [[BlockMatrix]] */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { // TODO: This implementation may be optimized toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) @@ -97,6 +105,7 @@ class IndexedRowMatrix( * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entries = rows.flatMap { row => val rowIndex = row.index @@ -133,6 +142,7 @@ class IndexedRowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V) */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -159,6 +169,7 @@ class IndexedRowMatrix( * @param B a local matrix whose number of rows must match the number of columns of this matrix * @return an IndexedRowMatrix representing the product, which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { val mat = toRowMatrix().multiply(B) val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => @@ -170,6 +181,7 @@ class IndexedRowMatrix( /** * Computes the Gramian matrix `A^T A`. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { toRowMatrix().computeGramianMatrix() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index bfc90c9ef852..78036eba5c3e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -28,7 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -45,6 +45,7 @@ import org.apache.spark.storage.StorageLevel * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. */ +@Since("1.0.0") @Experimental class RowMatrix( val rows: RDD[Vector], @@ -52,9 +53,11 @@ class RowMatrix( private var nCols: Int) extends DistributedMatrix with Logging { /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[Vector]) = this(rows, 0L, 0) /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { try { @@ -70,6 +73,7 @@ class RowMatrix( } /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { nRows = rows.count() @@ -108,6 +112,7 @@ class RowMatrix( /** * Computes the Gramian matrix `A^T A`. */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -178,6 +183,7 @@ class RowMatrix( * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -318,6 +324,7 @@ class RowMatrix( * Computes the covariance matrix, treating each row as an observation. 
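A brief sketch of the `computeSVD` call annotated above on an `IndexedRowMatrix`; `k = 5` and the helper name are illustrative:
```
import org.apache.spark.mllib.linalg.Matrix
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix

def topSingularValues(mat: IndexedRowMatrix): Unit = {
  // Keep the top 5 singular values; U is skipped unless computeU = true
  val svd = mat.computeSVD(5, computeU = true)
  val u: IndexedRowMatrix = svd.U   // distributed left singular vectors
  val s = svd.s                     // singular values as a local Vector
  val v: Matrix = svd.V             // local right singular vectors
  println(s"sigma(0) = ${s(0)}")
}
```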
* @return a local dense matrix of size n x n */ + @Since("1.0.0") def computeCovariance(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -371,6 +378,7 @@ class RowMatrix( * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components */ + @Since("1.0.0") def computePrincipalComponents(k: Int): Matrix = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") @@ -389,6 +397,7 @@ class RowMatrix( /** * Computes column-wise summary statistics. */ + @Since("1.0.0") def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), @@ -404,6 +413,7 @@ class RowMatrix( * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product, * which preserves partitioning */ + @Since("1.0.0") def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt val k = B.numCols @@ -436,6 +446,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities between * columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(): CoordinateMatrix = { columnSimilarities(0.0) } @@ -479,6 +490,7 @@ class RowMatrix( * @return An n x n sparse upper-triangular matrix of cosine similarities * between columns of this matrix. */ + @Since("1.2.0") def columnSimilarities(threshold: Double): CoordinateMatrix = { require(threshold >= 0, s"Threshold cannot be negative: $threshold") @@ -656,6 +668,7 @@ class RowMatrix( } } +@Since("1.0.0") @Experimental object RowMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 56c549ef99cb..b27ef1b949e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD @@ -26,8 +26,8 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. - * @since 0.8.0 */ +@Since("0.8.0") case class Rating(user: Int, product: Int, rating: Double) /** @@ -255,8 +255,8 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. 
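To tie together the `RowMatrix` routines annotated above (column statistics, PCA, covariance, column similarities), a minimal usage sketch; the input RDD and thresholds are illustrative:
```
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.rdd.RDD

def summarizeRows(rows: RDD[Vector]): Unit = {
  val mat = new RowMatrix(rows)

  val summary = mat.computeColumnSummaryStatistics()
  println(s"means: ${summary.mean}, variances: ${summary.variance}")

  val pc = mat.computePrincipalComponents(2)   // n x 2 matrix of principal components
  val projected = mat.multiply(pc)             // project rows onto the top-2 components

  // Upper-triangular cosine similarities between columns (approximate when threshold > 0)
  val sims = mat.columnSimilarities(0.1)

  val cov = mat.computeCovariance()
  println(s"covariance is ${cov.numRows} x ${cov.numCols}")
}
```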
- * @since 0.8.0 */ +@Since("0.8.0") object ALS { /** * Train a matrix factorization model given an RDD of ratings given by users to some products, @@ -271,8 +271,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed - * @since 0.9.1 */ + @Since("0.9.1") def train( ratings: RDD[Rating], rank: Int, @@ -296,8 +296,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @since 0.8.0 */ + @Since("0.8.0") def train( ratings: RDD[Rating], rank: Int, @@ -319,8 +319,8 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { train(ratings, rank, iterations, lambda, -1) @@ -336,8 +336,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { train(ratings, rank, iterations, 0.01, -1) @@ -357,8 +357,8 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -384,8 +384,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -409,8 +409,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, lambda, -1, alpha) @@ -427,8 +427,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 261ca9cef0c5..ba4cfdcd9f1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -30,6 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.mllib.linalg._ import 
org.apache.spark.mllib.rdd.MLPairRDDFunctions._ @@ -49,8 +50,8 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. - * @since 0.8.0 */ +@Since("0.8.0") class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], @@ -74,9 +75,8 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. - * @since 0.8.0 - */ + /** Predict the rating of one user for one product. */ + @Since("0.8.0") def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -114,8 +114,8 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. - * @since 0.9.0 */ + @Since("0.9.0") def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. // So if the usersProducts given for prediction contains only few products or @@ -146,8 +146,8 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. - * @since 1.2.0 */ + @Since("1.2.0") def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() } @@ -162,8 +162,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) @@ -179,8 +179,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) @@ -199,8 +199,8 @@ class MatrixFactorizationModel( * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -212,8 +212,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. 
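A sketch showing the `ALS` trainer and the `MatrixFactorizationModel` methods annotated here end to end; the rank, iteration count, user/product IDs, and save path are illustrative:
```
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD

def trainRecommender(sc: SparkContext, ratings: RDD[Rating]): Unit = {
  // rank = 10 latent features, 10 iterations, lambda = 0.01
  val model: MatrixFactorizationModel = ALS.train(ratings, 10, 10, 0.01)

  val score = model.predict(1, 42)                  // rating of product 42 for user 1
  val top3 = model.recommendProducts(1, 3)          // top 3 products for user 1
  val allTop = model.recommendProductsForUsers(3)   // top 3 products for every user

  model.save(sc, "target/tmp/alsModel")             // illustrative path
  val reloaded = MatrixFactorizationModel.load(sc, "target/tmp/alsModel")
}
```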
Semantics of score is same as recommendProducts API - * @since 1.4.0 */ + @Since("1.4.0") def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { case (user, top) => @@ -230,8 +230,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API - * @since 1.4.0 */ + @Since("1.4.0") def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { case (product, top) => @@ -241,9 +241,7 @@ class MatrixFactorizationModel( } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -326,8 +324,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. * @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 6709bd79bc82..509f6a2d169c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD @@ -34,7 +34,9 @@ import org.apache.spark.storage.StorageLevel * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) extends Serializable { @@ -53,7 +55,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. @@ -71,7 +75,9 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } @@ -88,14 +94,20 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * :: DeveloperApi :: * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). 
* This class should be extended with an Optimizer to create a new GLM. + * */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() - /** The optimizer to solve the problem. */ + /** + * The optimizer to solve the problem. + * + */ + @Since("1.0.0") def optimizer: Optimizer /** Whether to add intercept (default: false). */ @@ -130,7 +142,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The dimension of training features. + * */ + @Since("1.4.0") def getNumFeatures: Int = this.numFeatures /** @@ -153,13 +167,17 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Get if the algorithm uses addIntercept + * */ + @Since("1.4.0") def isAddIntercept: Boolean = this.addIntercept /** * Set if the algorithm should add an intercept. Default false. * We set the default to false because adding the intercept will cause memory allocation. + * */ + @Since("0.8.0") def setIntercept(addIntercept: Boolean): this.type = { this.addIntercept = addIntercept this @@ -167,7 +185,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Set if the algorithm should validate data before training. Default true. + * */ + @Since("0.8.0") def setValidateData(validateData: Boolean): this.type = { this.validateData = validateData this @@ -176,7 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. + * */ + @Since("0.8.0") def run(input: RDD[LabeledPoint]): M = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() @@ -208,7 +230,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. + * */ + @Since("1.0.0") def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { if (numFeatures < 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index f3b46c75c05f..31ca7c2f207d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -29,7 +29,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -46,7 +46,9 @@ import org.apache.spark.sql.SQLContext * @param predictions Array of predictions associated to the boundaries at the same index. * Results of isotonic regression and therefore monotone. * @param isotonic indicates whether this is isotonic or antitonic. + * */ +@Since("1.3.0") @Experimental class IsotonicRegressionModel ( val boundaries: Array[Double], @@ -59,7 +61,11 @@ class IsotonicRegressionModel ( assertOrdered(boundaries) assertOrdered(predictions)(predictionOrd) - /** A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. 
*/ + /** + * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. + * + */ + @Since("1.4.0") def this(boundaries: java.lang.Iterable[Double], predictions: java.lang.Iterable[Double], isotonic: java.lang.Boolean) = { @@ -83,7 +89,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: RDD[Double]): RDD[Double] = { testData.map(predict) } @@ -94,7 +102,9 @@ class IsotonicRegressionModel ( * * @param testData Features to be labeled. * @return Predicted labels. + * */ + @Since("1.3.0") def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) } @@ -114,7 +124,9 @@ class IsotonicRegressionModel ( * 3) If testData falls between two values in boundary array then prediction is treated * as piecewise linear function and interpolated value is returned. In case there are * multiple values with the same boundary then the same rules as in 2) are used. + * */ + @Since("1.3.0") def predict(testData: Double): Double = { def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { @@ -148,6 +160,7 @@ class IsotonicRegressionModel ( /** A convenient method for boundaries called by the Python API. */ private[mllib] def predictionVector: Vector = Vectors.dense(predictions) + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) } @@ -155,6 +168,7 @@ class IsotonicRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.4.0") object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { import org.apache.spark.mllib.util.Loader._ @@ -200,6 +214,9 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } } + /** + */ + @Since("1.4.0") override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index d5fea822ad77..f7fe1b7b21fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -28,7 +29,9 @@ import org.apache.spark.SparkException * * @param label Label for this data point. * @param features List of features for this data point. + * */ +@Since("0.8.0") @BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { @@ -38,12 +41,16 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. + * */ +@Since("1.1.0") object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. 
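For reference, a small sketch of constructing an `IsotonicRegressionModel` directly and exercising the prediction rules described above; the boundary and prediction arrays are illustrative:
```
import org.apache.spark.mllib.regression.IsotonicRegressionModel

// Boundaries must be ordered; predictions are the fitted values at those boundaries.
val model = new IsotonicRegressionModel(
  Array(0.0, 1.0, 3.0, 10.0),   // boundaries
  Array(1.0, 2.0, 2.0, 6.0),    // predictions
  true)                         // isotonic

model.predict(0.5)    // interpolated between the first two boundaries
model.predict(3.0)    // exact boundary match
model.predict(20.0)   // above all boundaries: prediction of the highest boundary
```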
+ * */ + @Since("1.1.0") def parse(s: String): LabeledPoint = { if (s.startsWith("(")) { NumericParser.parse(s) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 4f482384f0f3..556411a366bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -30,7 +31,9 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ +@Since("0.8.0") class LassoModel ( override val weights: Vector, override val intercept: Double) @@ -44,6 +47,7 @@ class LassoModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -51,8 +55,10 @@ class LassoModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object LassoModel extends Loader[LassoModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): LassoModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -106,7 +112,9 @@ class LassoWithSGD private ( /** * Top-level methods for calling Lasso. + * */ +@Since("0.8.0") object LassoWithSGD { /** @@ -123,7 +131,9 @@ object LassoWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -146,7 +156,9 @@ object LassoWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -167,7 +179,9 @@ object LassoWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -185,7 +199,9 @@ object LassoWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. 
+ * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 9453c4f66c21..00ab06e3ba73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -30,7 +31,9 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ +@Since("0.8.0") class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) @@ -44,6 +47,7 @@ class LinearRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -51,8 +55,10 @@ class LinearRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object LinearRegressionModel extends Loader[LinearRegressionModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): LinearRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -105,7 +111,9 @@ class LinearRegressionWithSGD private[mllib] ( /** * Top-level methods for calling LinearRegression. + * */ +@Since("0.8.0") object LinearRegressionWithSGD { /** @@ -121,7 +129,9 @@ object LinearRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -142,7 +152,9 @@ object LinearRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -161,7 +173,9 @@ object LinearRegressionWithSGD { * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -178,7 +192,9 @@ object LinearRegressionWithSGD { * matrix A as well as the corresponding right hand side label y * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. 
+ * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LinearRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 214ac4d0ed7d..0e72d6591ce8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,11 +19,12 @@ package org.apache.spark.mllib.regression import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD +@Since("0.8.0") @Experimental trait RegressionModel extends Serializable { /** @@ -31,7 +32,9 @@ trait RegressionModel extends Serializable { * * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -39,14 +42,18 @@ trait RegressionModel extends Serializable { * * @param testData array representing a single data point * @return Double prediction from the trained model + * */ + @Since("1.0.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + * */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 7d28ffad45c9..21a791d98b2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -31,7 +32,9 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. + * */ +@Since("0.8.0") class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) @@ -45,6 +48,7 @@ class RidgeRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -52,8 +56,10 @@ class RidgeRegressionModel ( override protected def formatVersion: String = "1.0" } +@Since("1.3.0") object RidgeRegressionModel extends Loader[RidgeRegressionModel] { + @Since("1.3.0") override def load(sc: SparkContext, path: String): RidgeRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -108,7 +114,9 @@ class RidgeRegressionWithSGD private ( /** * Top-level methods for calling RidgeRegression. 
+ * */ +@Since("0.8.0") object RidgeRegressionWithSGD { /** @@ -124,7 +132,9 @@ object RidgeRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -146,7 +156,9 @@ object RidgeRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -166,7 +178,9 @@ object RidgeRegressionWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -183,7 +197,9 @@ object RidgeRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. + * */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 141052ba813e..cd3ed8a1549d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} @@ -53,7 +53,9 @@ import org.apache.spark.streaming.dstream.DStream * It is also ok to call trainOn on different streams; this will update * the model using each of the different sources, in sequence. * + * */ +@Since("1.1.0") @DeveloperApi abstract class StreamingLinearAlgorithm[ M <: GeneralizedLinearModel, @@ -65,7 +67,11 @@ abstract class StreamingLinearAlgorithm[ /** The algorithm to use for updating. */ protected val algorithm: A - /** Return the latest model. */ + /** + * Return the latest model. + * + */ + @Since("1.1.0") def latestModel(): M = { model.get } @@ -77,7 +83,9 @@ abstract class StreamingLinearAlgorithm[ * batch of data from the stream. * * @param data DStream containing labeled data + * */ + @Since("1.3.0") def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") @@ -95,7 +103,11 @@ abstract class StreamingLinearAlgorithm[ } } - /** Java-friendly version of `trainOn`. */ + /** + * Java-friendly version of `trainOn`. 
+ * + */ + @Since("1.3.0") def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) /** @@ -103,7 +115,9 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing feature vectors * @return DStream containing predictions + * */ + @Since("1.1.0") def predictOn(data: DStream[Vector]): DStream[Double] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") @@ -111,7 +125,11 @@ abstract class StreamingLinearAlgorithm[ data.map{x => model.get.predict(x)} } - /** Java-friendly version of `predictOn`. */ + /** + * Java-friendly version of `predictOn`. + * + */ + @Since("1.1.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } @@ -121,7 +139,9 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing feature vectors * @tparam K key type * @return DStream containing the input keys and the predictions as values + * */ + @Since("1.1.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") @@ -130,7 +150,11 @@ abstract class StreamingLinearAlgorithm[ } - /** Java-friendly version of `predictOnValues`. */ + /** + * Java-friendly version of `predictOnValues`. + * + */ + @Since("1.3.0") def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { implicit val tag = fakeClassTag[K] JavaPairDStream.fromPairDStream( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index c6d04464a12b..537a05274eec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -39,7 +39,6 @@ import org.apache.spark.mllib.linalg.Vector * .setNumIterations(10) * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) - * */ @Experimental class StreamingLinearRegressionWithSGD private[mllib] ( @@ -61,31 +60,41 @@ class StreamingLinearRegressionWithSGD private[mllib] ( protected var model: Option[LinearRegressionModel] = None - /** Set the step size for gradient descent. Default: 0.1. */ + /** + * Set the step size for gradient descent. Default: 0.1. + */ def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } - /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + /** + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } - /** Set the fraction of each batch to use for updates. Default: 1.0. */ + /** + * Set the fraction of each batch to use for updates. Default: 1.0. + */ def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } - /** Set the initial weights. */ + /** + * Set the initial weights. + */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } - /** Set the convergence tolerance. 
*/ + /** + * Set the convergence tolerance. + */ def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 93a6753efd4d..4a856f7f3434 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD @@ -37,8 +37,8 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} - * @since 1.4.0 */ +@Since("1.4.0") @Experimental class KernelDensity extends Serializable { @@ -52,8 +52,8 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). - * @since 1.4.0 */ + @Since("1.4.0") def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") this.bandwidth = bandwidth @@ -62,8 +62,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. - * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: RDD[Double]): this.type = { this.sample = sample this @@ -71,8 +71,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). - * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] this @@ -80,8 +80,8 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. - * @since 1.4.0 */ + @Since("1.4.0") def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample val bandwidth = this.bandwidth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 62da9f2ef22a..51b713e263e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector} /** @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. - * @since 1.1.0 */ +@Since("1.1.0") @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -53,8 +53,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. 
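A short sketch of the `KernelDensity` builder annotated above, following the example given in its scaladoc; the bandwidth and evaluation points are illustrative:
```
import org.apache.spark.mllib.stat.KernelDensity
import org.apache.spark.rdd.RDD

def estimateDensity(sample: RDD[Double]): Array[Double] = {
  val kd = new KernelDensity()
    .setSample(sample)
    .setBandwidth(3.0)
  // Density of the smoothed distribution evaluated at three points
  kd.estimate(Array(-1.0, 2.0, 5.0))
}
```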
- * @since 1.1.0 */ + @Since("1.1.0") def add(sample: Vector): this.type = { if (n == 0) { require(sample.size > 0, s"Vector should have dimension larger than zero.") @@ -109,8 +109,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. - * @since 1.1.0 */ + @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + @@ -153,8 +153,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Sample mean of each dimension. + * */ + @Since("1.1.0") override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -168,8 +170,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Sample variance of each dimension. + * */ + @Since("1.1.0") override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -193,13 +197,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Sample size. + * */ + @Since("1.1.0") override def count: Long = totalCnt /** - * @since 1.1.0 + * Number of nonzero elements in each dimension. + * */ + @Since("1.1.0") override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -207,8 +215,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Maximum value of each dimension. + * */ + @Since("1.1.0") override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -221,8 +231,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.1.0 + * Minimum value of each dimension. + * */ + @Since("1.1.0") override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -235,8 +247,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.2.0 + * L2 (Euclidean) norm of each dimension. + * */ + @Since("1.2.0") override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -252,8 +266,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** - * @since 1.2.0 + * L1 norm of each dimension. + * */ + @Since("1.2.0") override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 3bb49f12289e..39a16fb743d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -17,59 +17,60 @@ package org.apache.spark.mllib.stat +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. - * @since 1.0.0 */ +@Since("1.0.0") trait MultivariateStatisticalSummary { /** * Sample mean vector.
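The MultivariateOnlineSummarizer annotated above accumulates column-wise statistics one vector at a time and can merge partial summaries; a small sketch with made-up values:
{{{
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val summarizer = new MultivariateOnlineSummarizer()
summarizer.add(Vectors.dense(1.0, 2.0, 3.0))
summarizer.add(Vectors.dense(4.0, 0.0, 6.0))

println(summarizer.mean)         // [2.5,1.0,4.5]
println(summarizer.count)        // 2
println(summarizer.numNonzeros)  // [2.0,1.0,2.0]

// Partial summaries built on separate partitions can be combined with merge().
val other = new MultivariateOnlineSummarizer().add(Vectors.dense(7.0, 8.0, 9.0))
summarizer.merge(other)
println(summarizer.count)        // 3
}}}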
- * @since 1.0.0 */ + @Since("1.0.0") def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. - * @since 1.0.0 */ + @Since("1.0.0") def variance: Vector /** * Sample size. - * @since 1.0.0 */ + @Since("1.0.0") def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. - * @since 1.0.0 */ + @Since("1.0.0") def numNonzeros: Vector /** * Maximum value of each column. - * @since 1.0.0 */ + @Since("1.0.0") def max: Vector /** * Minimum value of each column. - * @since 1.0.0 */ + @Since("1.0.0") def min: Vector /** * Euclidean magnitude of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL2: Vector /** * L1 norm of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f84502919e38..84d64a5bfb38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -32,8 +32,8 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. - * @since 1.1.0 */ +@Since("1.1.0") @Experimental object Statistics { @@ -42,8 +42,8 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. - * @since 1.1.0 */ + @Since("1.1.0") def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } @@ -54,8 +54,8 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** @@ -71,8 +71,8 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** @@ -85,14 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s - * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -109,14 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. 
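As a rough usage sketch for the Statistics.colStats and Statistics.corr entry points annotated above; `sc` is an assumed SparkContext and `rows` an assumed RDD[Vector]:
{{{
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD

val summary = Statistics.colStats(rows)   // rows: RDD[Vector], assumed
println(summary.mean)
println(summary.variance)

// Correlation between two series; "pearson" is the default, "spearman" is also supported.
val x: RDD[Double] = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
val y: RDD[Double] = sc.parallelize(Seq(5.0, 6.0, 7.0, 9.0))
val pearson = Statistics.corr(x, y)
val spearman = Statistics.corr(x, y, "spearman")
}}}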
- * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -133,8 +133,8 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } @@ -148,8 +148,8 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** @@ -159,8 +159,8 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** @@ -172,12 +172,16 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } + /** Java-friendly version of [[chiSqTest()]] */ + @Since("1.5.0") + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + /** * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a * continuous distribution. By comparing the largest difference between the empirical cumulative @@ -191,6 +195,7 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. */ + @Since("1.5.0") def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double) : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, cdf) @@ -207,9 +212,20 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. 
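A sketch of the hypothesis-testing APIs covered above; the observed counts are made up and `data` is an assumed RDD[Double]:
{{{
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.Statistics

// Goodness of fit against a uniform expected distribution (made-up counts).
val observed = Vectors.dense(4.0, 6.0, 5.0)
val gof = Statistics.chiSqTest(observed)
println(gof.pValue)

// Independence test on a 2x2 contingency matrix (values in column-major order).
val contingency = Matrices.dense(2, 2, Array(13.0, 47.0, 40.0, 80.0))
println(Statistics.chiSqTest(contingency).statistic)

// One-sample, two-sided KS test of `data` against a standard normal distribution.
val ks = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0)
println(ks.pValue)
}}}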
*/ + @Since("1.5.0") @varargs def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*) : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @Since("1.5.0") + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: Double*): KolmogorovSmirnovTestResult = { + kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 9aa7763d7890..bd4d81390bfa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.distribution import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} -import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils @@ -32,8 +32,8 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution - * @since 1.3.0 */ +@Since("1.3.0") @DeveloperApi class MultivariateGaussian ( val mu: Vector, @@ -62,15 +62,15 @@ class MultivariateGaussian ( private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index cecd1fed896d..972841015d4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo @@ -44,6 +44,7 @@ import org.apache.spark.util.random.XORShiftRandom * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. */ +@Since("1.0.0") @Experimental class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { @@ -54,6 +55,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. 
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) @@ -62,6 +64,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } +@Since("1.0.0") object DecisionTree extends Serializable with Logging { /** @@ -79,7 +82,8 @@ object DecisionTree extends Serializable with Logging { * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction - */ + */ + @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).run(input) } @@ -101,6 +105,7 @@ object DecisionTree extends Serializable with Logging { * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -128,6 +133,7 @@ object DecisionTree extends Serializable with Logging { * @param numClasses number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -161,6 +167,7 @@ object DecisionTree extends Serializable with Logging { * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -193,6 +200,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -208,6 +216,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] */ + @Since("1.1.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -237,6 +246,7 @@ object DecisionTree extends Serializable with Logging { * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction */ + @Since("1.1.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -250,6 +260,7 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] */ + @Since("1.1.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 9ce6faa137c4..e750408600c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint @@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param boostingStrategy Parameters for the gradient boosting algorithm. 
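To ground the DecisionTree entry points annotated above, a minimal classification sketch; `training` is an assumed RDD[LabeledPoint], and the empty map marks every feature as continuous:
{{{
import org.apache.spark.mllib.tree.DecisionTree

val model = DecisionTree.trainClassifier(
  training,                       // RDD[LabeledPoint], assumed
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),
  impurity = "gini",
  maxDepth = 5,
  maxBins = 32)

val labelAndPrediction = training.map(lp => (lp.label, model.predict(lp.features)))
}}}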
*/ +@Since("1.2.0") @Experimental class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { @@ -58,6 +59,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { @@ -75,6 +77,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. */ + @Since("1.2.0") def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } @@ -89,6 +92,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.4.0") def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -112,6 +116,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. */ + @Since("1.4.0") def runWithValidation( input: JavaRDD[LabeledPoint], validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -119,6 +124,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) } } +@Since("1.2.0") object GradientBoostedTrees extends Logging { /** @@ -130,6 +136,7 @@ object GradientBoostedTrees extends Logging { * @param boostingStrategy Configuration options for the boosting algorithm. * @return a gradient boosted trees model that can be used for prediction */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -139,6 +146,7 @@ object GradientBoostedTrees extends Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] */ + @Since("1.2.0") def train( input: JavaRDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 069959976a18..63a902f3eb51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy @@ -260,6 +260,7 @@ private class RandomForest ( } +@Since("1.2.0") object RandomForest extends Serializable with Logging { /** @@ -277,6 +278,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. 
* @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, @@ -314,6 +316,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -333,6 +336,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] */ + @Since("1.2.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -363,6 +367,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, @@ -399,6 +404,7 @@ object RandomForest extends Serializable with Logging { * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -417,6 +423,7 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] */ + @Since("1.2.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -434,6 +441,7 @@ object RandomForest extends Serializable with Logging { /** * List of supported feature subset sampling strategies. */ + @Since("1.2.0") val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "sqrt", "log2", "onethird") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index b6099259971b..8301ad160836 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to select the algorithm for the decision tree */ +@Since("1.0.0") @Experimental object Algo extends Enumeration { type Algo = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 50fe2ac53da9..7c569981977b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} @@ -39,6 +39,7 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * then stop. Ignored when * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. 
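And for the RandomForest trainers annotated above, with the same assumed `training` RDD; "auto" lets the library pick the feature subset strategy based on the number of trees:
{{{
import org.apache.spark.mllib.tree.RandomForest

val model = RandomForest.trainClassifier(
  training,
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),
  numTrees = 20,
  featureSubsetStrategy = "auto",
  impurity = "gini",
  maxDepth = 5,
  maxBins = 32,
  seed = 42)
}}}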
*/ +@Since("1.2.0") @Experimental case class BoostingStrategy( // Required boosting parameters @@ -70,6 +71,7 @@ case class BoostingStrategy( } } +@Since("1.2.0") @Experimental object BoostingStrategy { @@ -78,6 +80,7 @@ object BoostingStrategy { * @param algo Learning goal. Supported: "Classification" or "Regression" * @return Configuration for boosting algorithm */ + @Since("1.2.0") def defaultParams(algo: String): BoostingStrategy = { defaultParams(Algo.fromString(algo)) } @@ -89,6 +92,7 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ + @Since("1.3.0") def defaultParams(algo: Algo): BoostingStrategy = { val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index f4c877232750..bb7c7ee4f964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to describe whether a feature is "continuous" or "categorical" */ +@Since("1.0.0") @Experimental object FeatureType extends Enumeration { type FeatureType = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 7da976e55a72..904e42deebb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum for selecting the quantile calculation strategy */ +@Since("1.0.0") @Experimental object QuantileStrategy extends Enumeration { type QuantileStrategy = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index de2c78480944..b74e3f1f4652 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -67,6 +67,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * the checkpoint directory is not set in * [[org.apache.spark.SparkContext]], this setting is ignored. 
*/ +@Since("1.0.0") @Experimental class Strategy ( @BeanProperty var algo: Algo, @@ -83,10 +84,16 @@ class Strategy ( @BeanProperty var useNodeIdCache: Boolean = false, @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + /** + */ + @Since("1.2.0") def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 } + /** + */ + @Since("1.2.0") def isMulticlassWithCategoricalFeatures: Boolean = { isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } @@ -94,6 +101,7 @@ class Strategy ( /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ + @Since("1.1.0") def this( algo: Algo, impurity: Impurity, @@ -108,6 +116,7 @@ class Strategy ( /** * Sets Algorithm using a String. */ + @Since("1.2.0") def setAlgo(algo: String): Unit = algo match { case "Classification" => setAlgo(Classification) case "Regression" => setAlgo(Regression) @@ -116,6 +125,7 @@ class Strategy ( /** * Sets categoricalFeaturesInfo using a Java Map. */ + @Since("1.2.0") def setCategoricalFeaturesInfo( categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { this.categoricalFeaturesInfo = @@ -148,11 +158,6 @@ class Strategy ( s" Valid values are integers >= 0.") require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + s" Valid values are integers >= 2.") - categoricalFeaturesInfo.foreach { case (feature, arity) => - require(arity >= 2, - s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + - s" feature $feature has $arity categories. The number of categories should be >= 2.") - } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, @@ -162,7 +167,10 @@ class Strategy ( s"$subsamplingRate") } - /** Returns a shallow copy of this instance. */ + /** + * Returns a shallow copy of this instance. 
+ */ + @Since("1.2.0") def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, @@ -170,6 +178,7 @@ class Strategy ( } } +@Since("1.2.0") @Experimental object Strategy { @@ -177,6 +186,7 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ + @Since("1.2.0") def defaultStrategy(algo: String): Strategy = { defaultStrategy(Algo.fromString(algo)) } @@ -185,6 +195,7 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo Algo.Classification or Algo.Regression */ + @Since("1.3.0") def defaultStrategy(algo: Algo): Strategy = algo match { case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 9fe264656ede..21ee49c45788 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. 
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0768204c3391..73df6b054a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. */ +@Since("1.0.0") @Experimental object Entropy extends Impurity { @@ -36,6 +37,7 @@ object Entropy extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -63,6 +65,7 @@ object Entropy extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") @@ -71,6 +74,7 @@ object Entropy extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d0077db6832e..f21845b21a80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -25,6 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] * during binary classification. 
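The bin-count logic above hinges on the formula in the comment: a categorical feature with k categories, if treated as unordered, needs two bins for each of its 2^(k-1) - 1 candidate subset splits. A small illustrative helper (not the private Spark method itself, just the same arithmetic):
{{{
// 2 * ((1 << (k - 1)) - 1) bins for an unordered categorical feature with k categories.
def unorderedBinCount(k: Int): Int = 2 * ((1 << (k - 1)) - 1)

unorderedBinCount(3)  // 6
unorderedBinCount(4)  // 14
// This is why only low-arity categorical features are treated as unordered:
// the count grows exponentially and must stay within maxBins.
}}}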
*/ +@Since("1.0.0") @Experimental object Gini extends Impurity { @@ -35,6 +36,7 @@ object Gini extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -59,6 +61,7 @@ object Gini extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") @@ -67,6 +70,7 @@ object Gini extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 86cee7e430b0..4637dcceea7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -26,6 +26,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] * (b) calculating impurity values from sufficient statistics. */ +@Since("1.0.0") @Experimental trait Impurity extends Serializable { @@ -36,6 +37,7 @@ trait Impurity extends Serializable { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -47,6 +49,7 @@ trait Impurity extends Serializable { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 04d0cd24e663..a74197278d6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating variance during regression */ +@Since("1.0.0") @Experimental object Variance extends Impurity { @@ -33,6 +34,7 @@ object Variance extends Impurity { * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") @@ -45,6 +47,7 @@ object Variance extends Impurity { * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: 
Double, sum: Double, sumSquares: Double): Double = { if (count == 0) { @@ -58,6 +61,7 @@ object Variance extends Impurity { * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ + @Since("1.0.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 2bdef73c4a8f..bab7b8c6cadf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * |y - F(x)| * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object AbsoluteError extends Loss { @@ -41,6 +42,7 @@ object AbsoluteError extends Loss { * @param label True label. * @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { if (label - prediction < 0) 1.0 else -1.0 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 778c24526de7..b2b4594712f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * 2 log(1 + exp(-2 y F(x))) * where y is a label in {-1, 1} and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object LogLoss extends Loss { @@ -43,6 +44,7 @@ object LogLoss extends Loss { * @param label True label. * @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 64ffccbce073..687cde325ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. */ +@Since("1.2.0") @DeveloperApi trait Loss extends Serializable { @@ -36,6 +37,7 @@ trait Loss extends Serializable { * @param label true label. * @return Loss gradient. 
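The impurity objects annotated above expose their calculate() overloads as a developer API; a quick sketch with worked numbers:
{{{
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}

// Binary node with 3 examples of class 0 and 7 of class 1.
val counts = Array(3.0, 7.0)
Gini.calculate(counts, 10.0)     // 1 - 0.3^2 - 0.7^2 = 0.42
Entropy.calculate(counts, 10.0)  // -0.3*log2(0.3) - 0.7*log2(0.7) ≈ 0.881

// Regression impurity works from count/sum/sumSquares statistics instead.
Variance.calculate(count = 10.0, sum = 20.0, sumSquares = 50.0)  // 50/10 - (20/10)^2 = 1.0
}}}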
*/ + @Since("1.2.0") def gradient(prediction: Double, label: Double): Double /** @@ -46,6 +48,7 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data */ + @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map(point => computeError(model.predict(point.features), point.label)).mean() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala index 42c9ead9884b..2b112fbe1220 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.tree.loss +import org.apache.spark.annotation.Since + +@Since("1.2.0") object Losses { + @Since("1.2.0") def fromString(name: String): Loss = name match { case "leastSquaresError" => SquaredError case "leastAbsoluteError" => AbsoluteError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 011a5d57422f..3f7d3d38be16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * (y - F(x))**2 * where y is the label and F(x) is the model prediction for features x. */ +@Since("1.2.0") @DeveloperApi object SquaredError extends Loss { @@ -41,6 +42,7 @@ object SquaredError extends Loss { * @param label True label. 
* @return Loss gradient */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 2.0 * (label - prediction) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index f2c78bbabff0..3eefd135f783 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} @@ -41,6 +41,7 @@ import org.apache.spark.util.Utils * @param topNode root node * @param algo algorithm type -- classification or regression */ +@Since("1.0.0") @Experimental class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { @@ -50,6 +51,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features array representing a single data point * @return Double prediction from the trained model */ + @Since("1.0.0") def predict(features: Vector): Double = { topNode.predict(features) } @@ -60,6 +62,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features RDD representing data points to be predicted * @return RDD of predictions for each of the given data points */ + @Since("1.0.0") def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } @@ -70,6 +73,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param features JavaRDD representing data points to be predicted * @return JavaRDD of predictions for each of the given data points */ + @Since("1.2.0") def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { predict(features.rdd) } @@ -77,6 +81,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Get number of nodes in tree, including leaf nodes. */ + @Since("1.1.0") def numNodes: Int = { 1 + topNode.numDescendants } @@ -85,6 +90,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Get depth of tree. * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. */ + @Since("1.1.0") def depth: Int = { topNode.subtreeDepth } @@ -109,6 +115,12 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable header + topNode.subtreeToString(2) } + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. 
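To make the loss gradients above concrete, two evaluations at a single (prediction, label) pair; the expected values follow directly from the formulas quoted in the class docs:
{{{
import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError}

SquaredError.gradient(prediction = 1.5, label = 1.0)   // -2 * (1.0 - 1.5) = 1.0
AbsoluteError.gradient(prediction = 1.5, label = 1.0)  // label - prediction < 0, so 1.0
}}}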
+ */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } @@ -116,6 +128,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable override protected def formatVersion: String = DecisionTreeModel.formatVersion } +@Since("1.3.0") object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { private[spark] def formatVersion: String = "1.0" @@ -297,6 +310,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } } + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): DecisionTreeModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 508bf9c1bdb4..091a0462c204 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator * @param leftPredict left node predict * @param rightPredict right node predict */ +@Since("1.0.0") @DeveloperApi class InformationGainStats( val gain: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index a6d1398fc267..8c54c5510723 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vector @@ -39,6 +39,7 @@ import org.apache.spark.mllib.linalg.Vector * @param rightNode right child * @param stats information gain stats */ +@Since("1.0.0") @DeveloperApi class Node ( val id: Int, @@ -59,6 +60,7 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @Since("1.0.0") @deprecated("build should no longer be used since trees are constructed on-the-fly in training", "1.2.0") def build(nodes: Array[Node]): Unit = { @@ -80,6 +82,7 @@ class Node ( * @param features feature value * @return predicted value */ + @Since("1.1.0") def predict(features: Vector) : Double = { if (isLeaf) { predict.predict diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 5cbe7c280dbe..965784051ede 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi 
+import org.apache.spark.annotation.{DeveloperApi, Since} /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ +@Since("1.2.0") @DeveloperApi class Predict( val predict: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index be6c9b3de547..45db83ae3a1f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * @param featureType type of feature -- categorical or continuous * @param categories Split left if categorical feature value is in this set, else right. */ +@Since("1.0.0") @DeveloperApi case class Split( feature: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 905c5fb42bd4..19571447a2c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -46,6 +46,7 @@ import org.apache.spark.util.Utils * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles */ +@Since("1.2.0") @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), @@ -54,6 +55,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis require(trees.forall(_.algo == algo)) + /** + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, RandomForestModel.SaveLoadV1_0.thisClassName) @@ -62,10 +70,18 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis override protected def formatVersion: String = RandomForestModel.formatVersion } +@Since("1.3.0") object RandomForestModel extends Loader[RandomForestModel] { private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + /** + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
+ * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -97,6 +113,7 @@ object RandomForestModel extends Loader[RandomForestModel] { * @param trees tree ensembles * @param treeWeights tree ensemble weights */ +@Since("1.2.0") @Experimental class GradientBoostedTreesModel( override val algo: Algo, @@ -107,6 +124,12 @@ class GradientBoostedTreesModel( require(trees.length == treeWeights.length) + /** + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) @@ -119,6 +142,7 @@ class GradientBoostedTreesModel( * @return an array with index i having the losses or errors for the ensemble * containing the first i+1 trees */ + @Since("1.4.0") def evaluateEachIteration( data: RDD[LabeledPoint], loss: Loss): Array[Double] = { @@ -159,6 +183,9 @@ class GradientBoostedTreesModel( override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion } +/** + */ +@Since("1.3.0") object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** @@ -171,6 +198,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @return a RDD with each element being a zip of the prediction and error * corresponding to every sample. */ + @Since("1.4.0") def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -194,6 +222,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @return a RDD with each element being a zip of the prediction and error * corresponding to each sample. */ + @Since("1.4.0") def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], @@ -213,6 +242,12 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + /** + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
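The tree ensemble models follow the same persistence pattern, and GradientBoostedTreesModel additionally reports per-iteration error; a sketch with assumed `sc`, `training`, and already-trained `rfModel`/`gbtModel` instances (paths are hypothetical):
{{{
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel}

rfModel.save(sc, "rfModelPath")
val loadedRf = RandomForestModel.load(sc, "rfModelPath")

gbtModel.save(sc, "gbtModelPath")
val loadedGbt = GradientBoostedTreesModel.load(sc, "gbtModelPath")

// One entry per tree: error of the ensemble truncated after i+1 trees.
val errors: Array[Double] = loadedGbt.evaluateEachIteration(training, SquaredError)
}}}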
+ * @return Model instance + */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7c5cfa7bd84c..4940974bf4f4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD @@ -65,6 +65,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, @@ -114,6 +115,7 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -127,12 +129,14 @@ object MLUtils { * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -141,6 +145,7 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -152,6 +157,7 @@ object MLUtils { * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. */ + @Since("1.0.0") def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -182,12 +188,14 @@ object MLUtils { * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -198,6 +206,7 @@ object MLUtils { * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -205,6 +214,7 @@ object MLUtils { * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. 
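For the MLUtils loaders annotated above, the common entry point is loadLibSVMFile; `sc` is an assumed SparkContext and the file path is only an example:
{{{
import org.apache.spark.mllib.util.MLUtils

// The number of features is inferred from the data when not supplied.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

// Or fix the feature dimension explicitly (useful when files are split across jobs).
val data100 = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt", 100)
}}}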
*/ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -221,6 +231,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => @@ -242,6 +253,7 @@ object MLUtils { * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ + @Since("1.0.0") @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) @@ -254,6 +266,7 @@ object MLUtils { * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. */ + @Since("1.0.0") @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat @@ -269,6 +282,7 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. */ + @Since("1.0.0") def appendBias(vector: Vector): Vector = { vector match { case dv: DenseVector => diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e9aa383728f..618b95b9bd12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -100,9 +100,7 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {1.0, 0.0}; - model.transform( - dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -112,9 +110,8 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. 
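The `kFold` helper annotated above splits an RDD into k (training, validation) pairs; a hedged sketch, assuming an existing `RDD[LabeledPoint]` named `points`.

```scala
import org.apache.spark.mllib.util.MLUtils

// Arguments: (rdd, numFolds, seed). Each pair holds the training split and its
// complementary 1/k validation split.
val folds = MLUtils.kFold(points, 3, 11)
folds.zipWithIndex.foreach { case ((training, validation), i) =>
  println(s"fold $i: ${training.count()} train / ${validation.count()} validation")
}
```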
- double[] thresholds2 = {0.6, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java new file mode 100644 index 000000000000..ec6b4bf3c0f8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaMultilayerPerceptronClassifierSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + sqlContext = null; + } + + @Test + public void testMLPC() { + DataFrame dataFrame = sqlContext.createDataFrame( + jsc.parallelize(Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), + LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() + .setLayers(new int[] {2, 5, 2}) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100); + MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); + DataFrame result = model.transform(dataFrame); + Row[] predictionAndLabels = result.select("prediction", "label").collect(); + for (Row r: predictionAndLabels) { + Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java new file mode 100644 index 000000000000..56988b9fb29c --- 
/dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; + + +public class JavaVectorSlicerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void vectorSlice() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + + DataFrame output = vectorSlicer.transform(dataset); + + for (Row r : output.select("userFeatures", "features").take(2)) { + Vector features = r.getAs(1); + Assert.assertEquals(features.size(), 2); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c8576..3fea359a3b46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -22,12 +22,14 @@ import java.util.Arrays; import scala.Tuple2; +import scala.Tuple3; import org.junit.After; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertArrayEquals; import org.junit.Before; 
import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; @@ -44,9 +46,9 @@ public class JavaLDASuite implements Serializable { public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); ArrayList> tinyCorpus = new ArrayList>(); - for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), - LDASuite$.MODULE$.tinyCorpus()[i]._2())); + for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); @@ -60,7 +62,7 @@ public void tearDown() { @Test public void localLDAModel() { - Matrix topics = LDASuite$.MODULE$.tinyTopics(); + Matrix topics = LDASuite.tinyTopics(); double[] topicConcentration = new double[topics.numRows()]; Arrays.fill(topicConcentration, 1.0D / topics.numRows()); LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); @@ -110,8 +112,8 @@ public void distributedLDAModel() { assertEquals(roundedLocalTopicSummary.length, k); // Check: log probabilities - assert(model.logLikelihood() < 0.0); - assert(model.logPrior() < 0.0); + assertTrue(model.logLikelihood() < 0.0); + assertTrue(model.logPrior() < 0.0); // Check: topic distributions JavaPairRDD topicDistributions = model.javaTopicDistributions(); @@ -124,6 +126,21 @@ public Boolean call(Tuple2 tuple2) { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + Tuple3 topTopics = model.javaTopTopicsPerDocument(3).first(); + Long docId = topTopics._1(); // confirm doc ID type + int[] topicIndices = topTopics._2(); + double[] topicWeights = topTopics._3(); + assertEquals(3, topicIndices.length); + assertEquals(3, topicWeights.length); + + // Check: topTopicAssignments + Tuple3 topicAssignment = model.javaTopicAssignments().first(); + Long docId2 = topicAssignment._1(); + int[] termIndices2 = topicAssignment._2(); + int[] topicIndices2 = topicAssignment._3(); + assertEquals(termIndices2.length, topicIndices2.length); } @Test @@ -160,11 +177,31 @@ public void OnlineOptimizerCompatibility() { assertEquals(roundedLocalTopicSummary.length, k); } - private static int tinyK = LDASuite$.MODULE$.tinyK(); - private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); - private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. 
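The Java checks above mirror Scala-side APIs on `DistributedLDAModel`; a sketch, assuming a model named `ldaModel` trained with the EM optimizer.

```scala
// topTopicsPerDocument: (docId, topicIndices, topicWeights) per document.
// topicAssignments: (docId, termIndices, topicIndices) per document.
val topTopics = ldaModel.topTopicsPerDocument(3)
val assignments = ldaModel.topicAssignments
topTopics.take(1).foreach { case (docId, indices, weights) =>
  println(s"doc $docId: topics ${indices.mkString(",")}, weights ${weights.mkString(",")}")
}
```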
+ ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + + private static int tinyK = LDASuite.tinyK(); + private static int tinyVocabSize = LDASuite.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite$.MODULE$.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite.toyModel(); + private ArrayList> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index b3815ae6039c..d7c2cb3ae206 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -49,7 +49,7 @@ public void runAssociationRules() { JavaRDD> freqItemsets = sc.parallelize(Lists.newArrayList( new FreqItemset(new String[] {"a"}, 15L), new FreqItemset(new String[] {"b"}, 35L), - new FreqItemset(new String[] {"a", "b"}, 18L) + new FreqItemset(new String[] {"a", "b"}, 12L) )); JavaRDD> results = (new AssociationRules()).run(freqItemsets); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c98..eb4e3698624b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -27,7 +27,12 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; @@ -53,4 +58,21 @@ public void testCorr() { // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 63d2fa31c749..1f2c9b75b617 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HashingTF import 
org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.sql.DataFrame class PipelineSuite extends SparkFunSuite { @@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite { .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) + MLTestingUtils.checkCopy(pipelineModel) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c7bbf1ce07a2..f680d8d3c4cc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -244,6 +245,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) + // copied model must have the same parent. + MLTestingUtils.checkCopy(newTree) + val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) .collect() @@ -257,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index d4b5896c12c0..e3909bccaa5c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { 
.setCheckpointInterval(2) val model = gbt.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + sc.checkpointDir = None Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8c3d4590f5ae..cce39f382f73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -94,12 +95,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("setThreshold, getThreshold") { val lr = new LogisticRegression // default - withClue("LogisticRegression should not have thresholds set by default") { - intercept[java.util.NoSuchElementException] { + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future lr.getThresholds } } - // Set via thresholds. + // Set via threshold. // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. lr.setThreshold(1.0) assert(lr.getThresholds === Array(0.0, 1.0)) @@ -107,10 +109,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getThresholds === Array(1.0, 0.0)) lr.setThreshold(0.5) assert(lr.getThresholds === Array(0.5, 0.5)) - // Test getThreshold - lr.setThresholds(Array(0.3, 0.7)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) - assert(lr.getThreshold ~== expectedThreshold relTol 1E-7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -118,6 +136,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { lr.setFitIntercept(false) val model = lr.fit(dataset) assert(model.intercept === 0.0) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { @@ -145,7 +166,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. 
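The consistency rule exercised by these assertions can be stated compactly; a sketch using only `ml.classification.LogisticRegression`.

```scala
import org.apache.spark.ml.classification.LogisticRegression

// setThreshold(t) is mirrored as thresholds = Array(1 - t, t); getThreshold
// recovers t from thresholds as 1 / (1 + thresholds(0) / thresholds(1)).
val lr = new LogisticRegression().setThreshold(0.5)
println(lr.getThresholds.mkString(", "))  // 0.5, 0.5
lr.setThresholds(Array(0.3, 0.7))
println(lr.getThreshold)                  // 1 / (1 + 0.3 / 0.7) ≈ 0.7
```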
val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(1.0, 0.0), + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() @@ -153,8 +174,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. + lr.setThresholds(Array(0.6, 0.4)) val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, - lr.thresholds -> Array(0.6, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index bd8e819f6926..977f0e0b70c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + assert(ovaModel.models.size === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6ca4b5aa5fde..b4403ec30049 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + // copied model must have the same parent. 
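A sketch of the `OneVsRest` estimator exercised in the suite above, assuming a DataFrame `dataset` with the usual "label" and "features" columns; the base classifier choice is illustrative.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val ova = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
val ovaModel = ova.fit(dataset)
println(s"trained ${ovaModel.models.size} one-vs-rest binary models")
```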
+ MLTestingUtils.checkCopy(model) + val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 1f15ac02f400..688b0e31f91d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,10 +52,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getFeaturesCol === "features") assert(kmeans.getPredictionCol === "prediction") assert(kmeans.getMaxIter === 20) - assert(kmeans.getRuns === 1) assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 5) - assert(kmeans.getEpsilon === 1e-4) + assert(kmeans.getTol === 1e-4) } test("set parameters") { @@ -64,21 +63,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setFeaturesCol("test_feature") .setPredictionCol("test_prediction") .setMaxIter(33) - .setRuns(7) .setInitMode(MLlibKMeans.RANDOM) .setInitSteps(3) .setSeed(123) - .setEpsilon(1e-3) + .setTol(1e-3) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") assert(kmeans.getPredictionCol === "test_prediction") assert(kmeans.getMaxIter === 33) - assert(kmeans.getRuns === 7) assert(kmeans.getInitMode === MLlibKMeans.RANDOM) assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) - assert(kmeans.getEpsilon === 1e-3) + assert(kmeans.getTol === 1e-3) } test("parameters validation") { @@ -91,9 +88,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } - intercept[IllegalArgumentException] { - new KMeans().setRuns(0) - } } test("fit & transform") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 5b203784559e..aa722da32393 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) // r2 score evaluator.setMetricName("r2") @@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ec85e0d151e0..0eba34fda622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala new file mode 100644 index 000000000000..e192fa4850af --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) + } + + private def split(s: String): Seq[String] = s.split("\\s+") + + test("CountVectorizerModel common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, split("a b b c d a"), + Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split(""), Vectors.sparse(4, Seq())), // empty string + (4, split("a notInDict d"), + Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary + )).toDF("id", "words", "expected") + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer common cases") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d e"), + Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), + (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), + (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), + (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + ).toDF("id", "words", "expected") + val cv = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .fit(df) + assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer vocabSize and minDF") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), 
Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq((0, 1.0))))) + ).toDF("id", "words", "expected") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .fit(df) + assert(cvModel.vocabulary === Array("a", "b", "c")) + + // minDF: ignore terms with count less than 3 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3) + .fit(df) + assert(cvModel2.vocabulary === Array("a", "b")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // minDF: ignore terms with freq < 0.75 + val cvModel3 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(3.0 / df.count()) + .fit(df) + assert(cvModel3.vocabulary === Array("a", "b")) + + cvModel3.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer throws exception when vocab is empty") { + intercept[IllegalArgumentException] { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a b b c c")), + (1, split("aa bb cc"))) + ).toDF("id", "words") + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) // limit vocab size to 3 + .setMinDF(3) + .fit(df) + } + } + + test("CountVectorizerModel with minTF count") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq())), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizerModel with minTF freq") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), + (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), + (3, split("e e e e e"), Vectors.sparse(4, Seq()))) + ).toDF("id", "words", "expected") + + // minTF: count + val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + .setInputCol("words") + .setOutputCol("features") + .setMinTF(0.3) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala deleted file mode 100644 index e90d9d4ef21f..000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
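For readers of the new suite, a compact sketch of the knobs it exercises (`vocabSize`, `minDF`), assuming an existing `sqlContext`; the data and column names are illustrative.

```scala
import org.apache.spark.ml.feature.CountVectorizer

val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")
val cvModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)    // keep at most the 3 most frequent terms
  .setMinDF(2.0)      // a term must appear in at least 2 documents
  .fit(docs)
println(cvModel.vocabulary.mkString(", "))
cvModel.transform(docs).select("features").show()
```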
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.ml.feature - -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ - -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("params") { - ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) - } - - test("CountVectorizerModel common cases") { - val df = sqlContext.createDataFrame(Seq( - (0, "a b c d".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), - (1, "a b b c d a".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))), - (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string - (4, "a notInDict d".split(" ").toSeq, - Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary - )).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) - .setInputCol("words") - .setOutputCol("features") - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) - } - } - - test("CountVectorizerModel with minTermFreq") { - val df = sqlContext.createDataFrame(Seq( - (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), - (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))), - (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())), - (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq()))) - ).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) - .setInputCol("words") - .setOutputCol("features") - .setMinTermFreq(3) - val output = cv.transform(df).collect() - output.foreach { p => - val features = p.getAs[Vector]("features") - val expected = p.getAs[Vector]("expected") - assert(features ~== expected absTol 1e-14) - } - } -} - - diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c452054bec92..c04dda41eea3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} @@ -51,6 +52,9 @@ class MinMaxScalerSuite extends 
SparkFunSuite with MLlibTestSparkContext { .foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1.equals(vector2), "Transformed vector is different with expected.") } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index d0ae36b28c7a..30c500f87a76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { .setK(3) .fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(pca) + pca.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 000000000000..d19052881ae4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0295a0fe2fc..05e05bdc64bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") { @@ -37,6 +42,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + + // copied model must have the same parent. 
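A usage sketch of the `SQLTransformer` covered just above, assuming an existing `sqlContext`; `__THIS__` is the placeholder for the input DataFrame.

```scala
import org.apache.spark.ml.feature.SQLTransformer

val original = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
sqlTrans.transform(original).show()
```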
+ MLTestingUtils.checkCopy(indexer) + val transformed = indexer.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] @@ -47,19 +56,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) + } + + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) } test("StringIndexer with a numeric input column") { @@ -88,4 +115,54 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } + + test("StringIndexer, IndexToString are inverses") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", 
"label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val idx2str = new IndexToString() + .setInputCol("labelIndex") + .setOutputCol("sameLabel") + .setLabels(indexer.labels) + idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { + case Row(a: String, b: String) => + assert(a === b) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 03120c828ca9..8cb0a2cf14d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L test("Throws error when given RDDs with different size vectors") { val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work intercept[SparkException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index adcda0e623b2..a2e46f202995 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -62,6 +63,9 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea01..2c878f8372a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -199,6 +199,17 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) + + val arrayLengthGt = ParamValidators.arrayLengthGt[Int](2.0) + assert(arrayLengthGt(Array(0, 1, 2)) && !arrayLengthGt(Array(0, 1))) + } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2e5cfe7027eb..eadc80e0e62b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 33aa9d0d6234..b092bcd6a7e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dbdce0c9dea5..a68197b59193 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxDepth(2) .setMaxIter(2) val model = gbt.fit(df) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) val preds = model.transform(df) val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 21ad8225bd9f..2aaee71ecc73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -72,6 +73,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getFitIntercept) assert(lir.getStandardization) val model = lir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "prediction") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 992ce9562434..7b1b3f11481d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -91,7 +92,11 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) - val importances = rf.fit(df).featureImportances + val model = rf.fit(df) + + // copied model must have the same parent. 
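The `featureImportances` read out in the next hunk is a normalized vector; a sketch of the same call outside the test, assuming a `RandomForestRegressionModel` named `model` fitted as above.

```scala
// Importances are normalized to sum to 1.0 across features; argmax gives the most important one.
val importances = model.featureImportances
println(s"most important feature index: ${importances.argmax}")
```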
+ MLTestingUtils.checkCopy(model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index db64511a7605..fde02e0c84bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} @@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) @@ -138,6 +143,8 @@ object CrossValidatorSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index c8e58f216cce..ef24e6fb6b80 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -132,6 +132,8 @@ object TrainValidationSplitSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 000000000000..d290cc9b06e7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index fdc2554ab853..8a714f9b79e0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.{ArrayList => JArrayList} + import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite @@ -133,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } // Top 3 documents per topic - model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } // All documents per topic val q = tinyCorpus.length - model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } + + // Check: topTopicAssignments + // Make sure it assigns a topic to each term appearing in each doc. + val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] = + model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap + assert(topTopicAssignments.keys.max < tinyCorpus.length) + tinyCorpus.foreach { case (docID: Long, doc: Vector) => + if (topTopicAssignments.contains(docID)) { + val (inds, vals) = topTopicAssignments(docID) + assert(inds.length === doc.numNonzeros) + // For "term" in actual doc, + // check that it has a topic assigned. 
+ doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term))) + } else { + assert(doc.numNonzeros === 0) + } + } } test("vertex indexing") { @@ -160,8 +179,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha.toArray.forall(_ === 2.0)) - assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } @@ -575,6 +594,17 @@ private[clustering] object LDASuite { Vectors.sparse(6, Array(4, 5), Array(1, 1)) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.length) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + def toyModel: LocalLDAModel = { val k = 2 val vocabSize = 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a270ba2562db..bfd6d5495f5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite { } } + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 532463e96fbb..3d2edf9d9451 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -43,6 +43,22 @@ ${project.version} + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index db9dc4f17cee..0df1dd621f6e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -17,11 +17,12 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.IOException; import java.util.List; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; -import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,10 +32,10 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import 
org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.TransportConf; + /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -46,11 +47,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); - private final ExternalShuffleBlockResolver blockManager; + @VisibleForTesting + final ExternalShuffleBlockResolver blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler(TransportConf conf) { - this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf)); + public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { + this(new OneForOneStreamManager(), + new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); } /** Enables mocking out the StreamManager and BlockManager. */ @@ -105,4 +108,22 @@ public StreamManager getStreamManager() { public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + + /** + * Register an (application, executor) with the given shuffle info. + * + * The "re-" is meant to highlight the intended use of this method -- when this service is + * restarted, this is used to restore the state of executors from before the restart. Normal + * registration will happen via a message handled in receive() + * + * @param appExecId + * @param executorInfo + */ + public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) { + blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo); + } + + public void close() { + blockManager.close(); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 022ed88a1648..79beec4429a9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -17,19 +17,24 @@ package org.apache.spark.network.shuffle; -import java.io.DataInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Iterator; -import java.util.Map; +import java.io.*; +import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import com.google.common.collect.Maps; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; +import org.iq80.leveldb.Options; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,25 +57,87 @@ public class 
ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); + private static final ObjectMapper mapper = new ObjectMapper(); + /** + * This is a common prefix to the key for each app registration we stick in leveldb, so they + * are easy to find, since leveldb lets you search based on prefix. + */ + private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; + private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + // Map containing all registered executors' metadata. - private final ConcurrentMap executors; + @VisibleForTesting + final ConcurrentMap executors; // Single-threaded Java executor used to perform expensive recursive directory deletion. private final Executor directoryCleaner; private final TransportConf conf; - public ExternalShuffleBlockResolver(TransportConf conf) { - this(conf, Executors.newSingleThreadExecutor( + @VisibleForTesting + final File registeredExecutorFile; + @VisibleForTesting + final DB db; + + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) + throws IOException { + this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( // Add `spark` prefix because it will run in NM in Yarn mode. NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); } // Allows tests to have more control over when directories are cleaned up. @VisibleForTesting - ExternalShuffleBlockResolver(TransportConf conf, Executor directoryCleaner) { + ExternalShuffleBlockResolver( + TransportConf conf, + File registeredExecutorFile, + Executor directoryCleaner) throws IOException { this.conf = conf; - this.executors = Maps.newConcurrentMap(); + this.registeredExecutorFile = registeredExecutorFile; + if (registeredExecutorFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + DB tmpDb; + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + registeredExecutorFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Let's just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. 
Creating new file, will not be able to " + + "recover state for existing applications", registeredExecutorFile, e); + if (registeredExecutorFile.isDirectory()) { + for (File f : registeredExecutorFile.listFiles()) { + f.delete(); + } + } + registeredExecutorFile.delete(); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb); + executors = reloadRegisteredExecutors(tmpDb); + db = tmpDb; + } else { + db = null; + executors = Maps.newConcurrentMap(); + } this.directoryCleaner = directoryCleaner; } @@ -81,6 +148,15 @@ public void registerExecutor( ExecutorShuffleInfo executorInfo) { AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); + try { + if (db != null) { + byte[] key = dbAppExecKey(fullId); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + db.put(key, value); + } + } catch (Exception e) { + logger.error("Error saving registered executors", e); + } executors.put(fullId, executorInfo); } @@ -136,6 +212,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { // Only touch executors associated with the appId that was removed. if (appId.equals(fullId.appId)) { it.remove(); + if (db != null) { + try { + db.delete(dbAppExecKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } if (cleanupLocalDirs) { logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); @@ -220,12 +303,23 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } + void close() { + if (db != null) { + try { + db.close(); + } catch (IOException e) { + logger.error("Exception closing leveldb with registered executors", e); + } + } + } + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ - private static class AppExecId { - final String appId; - final String execId; + public static class AppExecId { + public final String appId; + public final String execId; - private AppExecId(String appId, String execId) { + @JsonCreator + public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) { this.appId = appId; this.execId = execId; } @@ -252,4 +346,105 @@ public String toString() { .toString(); } } + + private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(Charsets.UTF_8); + } + + private static AppExecId parseDbAppExecKey(String s) throws IOException { + if (!s.startsWith(APP_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX); + } + String json = s.substring(APP_KEY_PREFIX.length() + 1); + AppExecId parsed = mapper.readValue(json, AppExecId.class); + return parsed; + } + + @VisibleForTesting + static ConcurrentMap reloadRegisteredExecutors(DB db) + throws IOException { + ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); + if (db != null) { + DBIterator itr = db.iterator(); + itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), Charsets.UTF_8); + if (!key.startsWith(APP_KEY_PREFIX)) { + break; + } + AppExecId id = parseDbAppExecKey(key); + ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); + registeredExecutors.put(id, shuffleInfo); + } + } + return registeredExecutors; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. 
+ */ + private static void checkVersion(DB db) throws IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != CURRENT_VERSION.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + CURRENT_VERSION); + } + storeVersion(db); + } + } + + private static void storeVersion(DB db) throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); + } + + + public static class StoreVersion { + + final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index cadc8e8369c6..102d4efb8bf3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -19,6 +19,8 @@ import java.util.Arrays; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable { /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ public final String shuffleManager; - public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + @JsonCreator + public ExecutorShuffleInfo( + @JsonProperty("localDirs") String[] localDirs, + @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, + @JsonProperty("shuffleManager") String shuffleManager) { this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 1c28fc1dff24..94a61d6caadc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -23,6 +23,9 @@ import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * A message sent from the driver to register with the MesosExternalShuffleService. 
*/ diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d02f4f0fdb68..3c6cb367dea4 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -21,9 +21,12 @@ import java.io.InputStream; import java.io.InputStreamReader; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -59,8 +62,8 @@ public static void afterAll() { } @Test - public void testBadRequests() { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + public void testBadRequests() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); @@ -91,7 +94,7 @@ public void testBadRequests() { @Test public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); @@ -110,7 +113,7 @@ public void testSortShuffleBlocks() throws IOException { @Test public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); @@ -126,4 +129,28 @@ public void testHashShuffleBlocks() throws IOException { block1Stream.close(); assertEquals(hashBlock1, block1); } + + @Test + public void jsonSerializationOfExecutorRegistration() throws IOException { + ObjectMapper mapper = new ObjectMapper(); + AppExecId appId = new AppExecId("foo", "bar"); + String appIdJson = mapper.writeValueAsString(appId); + AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class); + assertEquals(parsedAppId, appId); + + ExecutorShuffleInfo shuffleInfo = + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + String shuffleJson = mapper.writeValueAsString(shuffleInfo); + ExecutorShuffleInfo parsedShuffleInfo = + mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); + assertEquals(parsedShuffleInfo, shuffleInfo); + + // Intentionally keep these hard-coded strings in here, to check backwards-compatibility. 
+ // it's not legacy yet, but keeping this here in case anybody changes it + String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; + assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); + String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); + } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index d9d9c1bf2f17..2f4f1d0df478 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -42,7 +42,7 @@ public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); resolver.applicationRemoved("app", false /* cleanup */); @@ -65,7 +65,8 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, noThreadExecutor); + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -83,7 +84,7 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); @@ -99,7 +100,7 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 39aa49911d9c..a3f9a38b1aeb 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException { dataContext1.insertHashShuffleData(1, 0, exec1Blocks); conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new ExternalShuffleBlockHandler(conf); + handler = new 
ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index d4ec1956c1e2..aa99efda9494 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -43,8 +43,9 @@ public class ExternalShuffleSecuritySuite { TransportServer server; @Before - public void beforeEach() { - TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf)); + public void beforeEach() throws IOException { + TransportContext context = + new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, new TestSecretKeyHolder("my-app-id", "secret")); this.server = context.createServer(Arrays.asList(bootstrap)); diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 463f99ef3352..11ea7f3fd3cf 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,25 +17,21 @@ package org.apache.spark.network.yarn; +import java.io.File; import java.nio.ByteBuffer; import java.util.List; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.server.api.AuxiliaryService; -import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; -import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; -import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; -import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.apache.hadoop.yarn.server.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; -import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; @@ -79,11 +75,26 @@ public class YarnShuffleService extends AuxiliaryService { private TransportServer shuffleServer = null; // Handles registering executors and opening shuffle blocks - private ExternalShuffleBlockHandler blockHandler; + @VisibleForTesting + ExternalShuffleBlockHandler blockHandler; + + // Where to store & reload executor info for recovering state after an NM restart + @VisibleForTesting + File registeredExecutorFile; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static YarnShuffleService instance; public YarnShuffleService() { 
super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); + instance = this; } /** @@ -100,11 +111,24 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. Even if + // an application was stopped while the NM was down, we expect yarn to call stopApplication() + // when it comes back + registeredExecutorFile = + findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - blockHandler = new ExternalShuffleBlockHandler(transportConf); + try { + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + } catch (Exception e) { + logger.error("Failed to initialize external shuffle service", e); + } List bootstraps = Lists.newArrayList(); if (authEnabled) { @@ -116,9 +140,13 @@ protected void serviceInit(Configuration conf) { SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); TransportContext transportContext = new TransportContext(transportConf, blockHandler); shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests its useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}.", port, authEnabledString); + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } @Override @@ -161,6 +189,16 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } + private File findRegisteredExecutorFile(String[] localDirs) { + for (String dir: localDirs) { + File f = new File(dir, "registeredExecutors.ldb"); + if (f.exists()) { + return f; + } + } + return new File(localDirs[0], "registeredExecutors.ldb"); + } + /** * Close the shuffle server to clean up any associated state. 
*/ @@ -170,6 +208,9 @@ protected void serviceStop() { if (shuffleServer != null) { shuffleServer.close(); } + if (blockHandler != null) { + blockHandler.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -180,5 +221,4 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } - } diff --git a/pom.xml b/pom.xml index 2bcc55b040a2..d5945f2546d3 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/flume-sink external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl @@ -431,7 +432,7 @@ ${commons-lang3.version} - org.apache.commons + commons-lang commons-lang ${commons-lang2.version} @@ -654,6 +655,11 @@ jackson-databind ${fasterxml.jackson.version} + + com.fasterxml.jackson.core + jackson-annotations + ${fasterxml.jackson.version} + @@ -1597,7 +1603,7 @@ com.twitter parquet-hadoop-bundle ${hive.parquet.version} - runtime + compile org.apache.flume @@ -1894,6 +1900,7 @@ ${project.build.directory}/tmp ${spark.test.home} 1 + false false false true diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b60ae784c379..88745dc086a0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,8 +62,6 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // Parquet support is considered private. - excludePackage("org.apache.spark.sql.parquet"), // The old JSON RDD is removed in favor of streaming Jackson ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), @@ -155,11 +153,45 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") ) ++ Seq( // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") ) case v if v.startsWith("1.4") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce..04e0d49b178c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -42,11 +42,11 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", + "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -212,6 +212,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + /* Package pyspark artifacts in a separate zip file for YARN. 
*/ enable(PySparkAssembly.settings)(assembly) @@ -316,6 +319,8 @@ object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -325,9 +330,14 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } @@ -337,8 +347,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -346,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -382,13 +391,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -540,6 +552,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", diff --git a/python/pyspark/context.py b/python/pyspark/context.py index eb5b0bbbdac4..1b2a52ad6411 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -302,10 +302,10 @@ def applicationId(self): """ A unique identifier for the Spark application. Its format depends on the scheduler implementation. - (i.e. 
- in case of local spark app something like 'local-1433865536131' - in case of YARN something like 'application_1433865536131_34483' - ) + + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + >>> sc.applicationId # doctest: +ELLIPSIS u'local-...' """ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5978d8f4d3a0..83f808efc3bf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -34,6 +34,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors @@ -75,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -96,11 +99,15 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." 
+ + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification self.thresholds = \ Param(self, "thresholds", "Thresholds in multi-class classification" + @@ -109,29 +116,28 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True) + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ - # Under the hood we use thresholds so translate threshold to thresholds if applicable - if thresholds is None and threshold is not None: - kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -164,44 +170,65 @@ def getFitIntercept(self): def setThreshold(self, value): """ - Sets the value of :py:attr:`thresholds` using [1-value, value]. + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self - >>> lr = LogisticRegression() - >>> lr.getThreshold() - 0.5 - >>> lr.setThreshold(0.6) - LogisticRegression_... - >>> abs(lr.getThreshold() - 0.6) < 1e-5 - True + def getThreshold(self): + """ + Gets the value of threshold or its default value. """ - return self.setThresholds([1-value, value]) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. """ self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] return self def getThresholds(self): """ - Gets the value of thresholds or its default value. + If :py:attr:`thresholds` is set, return its value. 
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. """ - return self.getOrDefault(self.thresholds) + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - if self.isDefined(self.thresholds): - thresholds = self.getOrDefault(self.thresholds) - if len(thresholds) != 2: + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: raise ValueError("Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." + - " thresholds: " + ",".join(thresholds)) - return 1.0/(1.0+thresholds[0]/thresholds[1]) - else: - return 0.5 + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): @@ -656,6 +683,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H HasRawPredictionCol): """ Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b5e9b6549d9f..cb4c16e25a7a 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,6 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector __all__ = ['KMeans', 'KMeansModel'] @@ -35,15 +34,17 @@ def clusterCenters(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): """ - K-means Clustering + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. >>> from pyspark.mllib.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... 
(Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = sqlContext.createDataFrame(data, ["features"]) - >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() >>> len(centers) @@ -58,10 +59,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") - epsilon = Param(Params._dummy(), "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + @@ -69,21 +66,21 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") @keyword_only - def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self.k = Param(self, "k", "number of clusters to create") - self.epsilon = Param(self, "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") - self.seed = Param(self, "seed", "random seed") self.initMode = Param(self, "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++") self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") - self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -91,9 +88,11 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only - def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ - setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) Sets params for KMeans. """ @@ -117,40 +116,6 @@ def getK(self): """ return self.getOrDefault(self.k) - def setEpsilon(self, value): - """ - Sets the value of :py:attr:`epsilon`. 
- - >>> algo = KMeans().setEpsilon(1e-5) - >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 - True - """ - self._paramMap[self.epsilon] = value - return self - - def getEpsilon(self): - """ - Gets the value of `epsilon` - """ - return self.getOrDefault(self.epsilon) - - def setRuns(self, value): - """ - Sets the value of :py:attr:`runs`. - - >>> algo = KMeans().setRuns(10) - >>> algo.getRuns() - 10 - """ - self._paramMap[self.runs] = value - return self - - def getRuns(self): - """ - Gets the value of `runs` - """ - return self.getOrDefault(self.runs) - def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 06e809352225..6b0a9ffde9f4 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -23,7 +23,8 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] @inherit_doc @@ -45,7 +46,7 @@ def _evaluate(self, dataset): """ raise NotImplementedError() - def evaluate(self, dataset, params={}): + def evaluate(self, dataset, params=None): """ Evaluates the output with optional parameters. @@ -55,6 +56,8 @@ def evaluate(self, dataset, params={}): params :return: metric """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params)._evaluate(dataset) @@ -160,11 +163,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): ... >>> evaluator = RegressionEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) - -2.842... + 2.842... >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) - -2.649... + 2.649... """ # Because we will maximize evaluation value (ref: `CrossValidator`), # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cb4dfa21298c..04b2b2ccc9e5 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,10 +26,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', - 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] +__all__ = ['Binarizer', 'Bucketizer', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', + 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', + 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', + 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc @@ -165,6 +166,63 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): + """ + Outputs the Hadamard product (i.e., the element-wise product) of each input vector + with a provided "weight" vector. In other words, it scales each column of the dataset + by a scalar multiplier. 
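As a quick illustration of the Hadamard (element-wise) product this transformer computes, a plain NumPy sketch that is independent of the Spark API:

    import numpy as np

    v = np.array([2.0, 1.0, 3.0])   # input vector
    w = np.array([1.0, 2.0, 3.0])   # scaling ("weight") vector
    print(v * w)                    # element-wise product -> [2. 2. 9.]
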
+ + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) + >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]), + ... inputCol="values", outputCol="eprod") + >>> ep.transform(df).head().eprod + DenseVector([2.0, 2.0, 9.0]) + >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod + DenseVector([4.0, 3.0, 15.0]) + """ + + # a placeholder to make it appear in the generated doc + scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + + @keyword_only + def __init__(self, scalingVec=None, inputCol=None, outputCol=None): + """ + __init__(self, scalingVec=None, inputCol=None, outputCol=None) + """ + super(ElementwiseProduct, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", + self.uid) + self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + + "it must be MLlib Vector type.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, scalingVec=None, inputCol=None, outputCol=None): + """ + setParams(self, scalingVec=None, inputCol=None, outputCol=None) + Sets params for this ElementwiseProduct. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setScalingVec(self, value): + """ + Sets the value of :py:attr:`scalingVec`. + """ + self._paramMap[self.scalingVec] = value + return self + + def getScalingVec(self): + """ + Gets the value of scalingVec or its default value. + """ + return self.getOrDefault(self.scalingVec) + + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 7845536161e0..eeeac49b2198 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -60,14 +60,16 @@ class Params(Identifiable): __metaclass__ = ABCMeta - #: internal param map for user-supplied values param map - _paramMap = {} + def __init__(self): + super(Params, self).__init__() + #: internal param map for user-supplied values param map + self._paramMap = {} - #: internal param map for default values - _defaultParamMap = {} + #: internal param map for default values + self._defaultParamMap = {} - #: value returned by :py:func:`params` - _params = None + #: value returned by :py:func:`params` + self._params = None @property def params(self): @@ -155,7 +157,7 @@ def getOrDefault(self, param): else: return self._defaultParamMap[param] - def extractParamMap(self, extra={}): + def extractParamMap(self, extra=None): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into @@ -165,12 +167,14 @@ def extractParamMap(self, extra={}): :param extra: extra param values :return: merged param map """ + if extra is None: + extra = dict() paramMap = self._defaultParamMap.copy() paramMap.update(self._paramMap) paramMap.update(extra) return paramMap - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. 
The default implementation creates a @@ -181,6 +185,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = copy.copy(self) that._paramMap = self.extractParamMap(extra) return that @@ -233,7 +239,7 @@ def _setDefault(self, **kwargs): self._defaultParamMap[getattr(self, param)] = value return self - def _copyValues(self, to, extra={}): + def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. @@ -241,6 +247,8 @@ def _copyValues(self, to, extra={}): :param extra: extra params to be copied :return: the target instance with param values copied """ + if extra is None: + extra = dict() paramMap = self.extractParamMap(extra) for p in self.params: if p in paramMap and to.hasParam(p.name): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9889f56cac9e..13cf2b0f7bbd 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -141,7 +141,7 @@ class Pipeline(Estimator): @keyword_only def __init__(self, stages=None): """ - __init__(self, stages=[]) + __init__(self, stages=None) """ if stages is None: stages = [] @@ -170,7 +170,7 @@ def getStages(self): @keyword_only def setParams(self, stages=None): """ - setParams(self, stages=[]) + setParams(self, stages=None) Sets params for Pipeline. """ if stages is None: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0bf988fd72f1..dcfee6a3170a 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -227,7 +227,9 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() newCV = Params.copy(self, extra) if self.isSet(self.estimator): newCV.setEstimator(self.getEstimator().copy(extra)) @@ -250,7 +252,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, @@ -259,6 +261,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5b7afc15ddfb..41946e3674fb 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -207,8 +207,10 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/(2n) ||A weights - y||^2, + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -334,7 +336,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. 
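The extra=None, stages=None and instance-level _paramMap changes above all guard against the same Python pitfall: default argument values and class-level attributes are created once and then shared across calls and instances. A minimal illustration, unrelated to the Spark classes themselves:

    def append_bad(item, acc=[]):      # the default list is created once, at definition time
        acc.append(item)
        return acc

    def append_good(item, acc=None):   # the idiom adopted throughout this patch
        if acc is None:
            acc = []
        acc.append(item)
        return acc

    print(append_bad(1), append_bad(2))    # [1, 2] [1, 2] -- state leaks between calls
    print(append_good(1), append_good(2))  # [1] [2]
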
This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -451,7 +455,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 3f5a02af12e3..5097c5e8ba4c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -32,6 +32,9 @@ from py4j.protocol import Py4JJavaError +if sys.version > '3': + basestring = str + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -86,9 +89,42 @@ def tearDown(self): self.ssc.stop(False) @staticmethod - def _ssc_wait(start_time, end_time, sleep_time): - while time() - start_time < end_time: + def _eventually(condition, timeout=30.0, catch_assertions=False): + """ + Wait a given amount of time for a condition to pass, else fail with an error. + This is a helper utility for streaming ML tests. + :param condition: Function that checks for termination conditions. + condition() can return: + - True: Conditions met. Return without error. + - other value: Conditions not met yet. Continue. Upon timeout, + include last such value in error message. + Note that this method may be called at any time during + streaming execution (e.g., even before any results + have been created). + :param timeout: Number of seconds to wait. Default 30 seconds. + :param catch_assertions: If False (default), do not catch AssertionErrors. + If True, catch AssertionErrors; continue, but save + error to throw upon timeout. + """ + start_time = time() + lastValue = None + while time() - start_time < timeout: + if catch_assertions: + try: + lastValue = condition() + except AssertionError as e: + lastValue = e + else: + lastValue = condition() + if lastValue is True: + return sleep(0.01) + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue)) def _squared_distance(a, b): @@ -999,10 +1035,13 @@ def test_accuracy_for_single_center(self): [self.sc.parallelize(batch, 1) for batch in batches]) stkm.trainOn(input_stream) - t = time() self.ssc.start() - self._ssc_wait(t, 10.0, 0.01) - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + + def condition(): + self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + return True + self._eventually(condition, catch_assertions=True) + realCenters = array_sum(array(centers), axis=0) for i in range(5): modelCenters = stkm.latestModel().centers[0][i] @@ -1027,7 +1066,7 @@ def test_trainOn_model(self): stkm.setInitialCenters( centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) - # Create a toy dataset by setting a tiny offest for each point. + # Create a toy dataset by setting a tiny offset for each point. 
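The corrected docstrings above all describe variants of one objective; for the L1-regularized form, f(w) = 1/(2n) * ||Aw - y||^2 + regParam * ||w||_1 can be checked numerically with a small NumPy sketch (toy data, illustrative only):

    import numpy as np

    A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # n x d data matrix
    y = np.array([1.0, 2.0, 3.0])                        # labels
    w = np.array([0.9, 2.1])                             # candidate weights
    regParam = 0.01
    n = A.shape[0]

    f = np.sum((A.dot(w) - y) ** 2) / (2.0 * n) + regParam * np.sum(np.abs(w))
    print(f)   # squared-error term plus L1 penalty
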
offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] batches = [] for offset in offsets: @@ -1037,14 +1076,15 @@ def test_trainOn_model(self): batches = [self.sc.parallelize(batch, 1) for batch in batches] input_stream = self.ssc.queueStream(batches) stkm.trainOn(input_stream) - t = time() self.ssc.start() # Give enough time to train the model. - self._ssc_wait(t, 6.0, 0.01) - finalModel = stkm.latestModel() - self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + def condition(): + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + return True + self._eventually(condition, catch_assertions=True) def test_predictOn_model(self): """Test that the model predicts correctly on toy data.""" @@ -1066,10 +1106,13 @@ def update(rdd): result.append(rdd_collect) predict_val.foreachRDD(update) - t = time() self.ssc.start() - self._ssc_wait(t, 6.0, 0.01) - self.assertEquals(result, [[0], [1], [2], [3]]) + + def condition(): + self.assertEquals(result, [[0], [1], [2], [3]]) + return True + + self._eventually(condition, catch_assertions=True) def test_trainOn_predictOn(self): """Test that prediction happens on the updated model.""" @@ -1095,10 +1138,13 @@ def collect(rdd): predict_stream = stkm.predictOn(input_stream) predict_stream.foreachRDD(collect) - t = time() self.ssc.start() - self._ssc_wait(t, 6.0, 0.01) - self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + + def condition(): + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + return True + + self._eventually(condition, catch_assertions=True) class LinearDataGeneratorTests(MLlibTestCase): @@ -1156,11 +1202,14 @@ def test_parameter_accuracy(self): slr.setInitialWeights([0.0]) slr.trainOn(input_stream) - t = time() self.ssc.start() - self._ssc_wait(t, 20.0, 0.01) - rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 - self.assertAlmostEqual(rel, 0.1, 1) + + def condition(): + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + return True + + self._eventually(condition, catch_assertions=True) def test_convergence(self): """ @@ -1179,13 +1228,18 @@ def test_convergence(self): input_stream.foreachRDD( lambda x: models.append(slr.latestModel().weights[0])) - t = time() self.ssc.start() - self._ssc_wait(t, 15.0, 0.01) + + def condition(): + self.assertEquals(len(models), len(input_batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, 60.0, catch_assertions=True) + t_models = array(models) diff = t_models[1:] - t_models[:-1] - - # Test that weights improve with a small tolerance, + # Test that weights improve with a small tolerance self.assertTrue(all(diff >= -0.1)) self.assertTrue(array_sum(diff > 0) > 1) @@ -1208,9 +1262,13 @@ def test_predictions(self): predict_stream = slr.predictOnValues(input_stream) true_predicted = [] predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) - t = time() self.ssc.start() - self._ssc_wait(t, 5.0, 0.01) + + def condition(): + self.assertEquals(len(true_predicted), len(input_batches)) + return True + + self._eventually(condition, catch_assertions=True) # Test that the accuracy error is no more than 0.4 on each batch. 
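The streaming-test rewrites in this file all follow the same shape: express the expected end state as a condition() that asserts and returns True, then let a polling helper retry it until a timeout. A generic standalone version of that pattern (hypothetical helper, simplified from the _eventually utility added above):

    import time

    def eventually(condition, timeout=30.0, interval=0.01):
        """Poll condition() until it returns True or the timeout expires."""
        start = time.time()
        last = None
        while time.time() - start < timeout:
            try:
                last = condition()        # mirrors catch_assertions=True
            except AssertionError as e:
                last = e
            if last is True:
                return
            time.sleep(interval)
        raise AssertionError("timed out; last result: %s" % (last,))

    results = [1]

    def condition():
        assert len(results) >= 1
        return True

    eventually(condition)   # returns immediately because the condition already holds
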
for batch in true_predicted: @@ -1242,12 +1300,17 @@ def collect_errors(rdd): ps = slr.predictOnValues(predict_stream) ps.foreachRDD(lambda x: collect_errors(x)) - t = time() self.ssc.start() - self._ssc_wait(t, 20.0, 0.01) - # Test that the improvement in error is atleast 0.3 - self.assertTrue(errors[1] - errors[-1] > 0.3) + def condition(): + # Test that the improvement in error is > 0.3 + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 0.3) + if len(errors) >= 3 and errors[1] - errors[-1] > 0.3: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): @@ -1274,13 +1337,16 @@ def test_parameter_accuracy(self): batches.append(sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) - t = time() slr.trainOn(input_stream) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - self.assertArrayAlmostEqual( - slr.latestModel().weights.array, [10., 10.], 1) - self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + + def condition(): + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + return True + + self._eventually(condition, catch_assertions=True) def test_parameter_convergence(self): """Test that the model parameters improve with streaming data.""" @@ -1298,13 +1364,18 @@ def test_parameter_convergence(self): input_stream = self.ssc.queueStream(batches) input_stream.foreachRDD( lambda x: model_weights.append(slr.latestModel().weights[0])) - t = time() slr.trainOn(input_stream) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - model_weights = array(model_weights) - diff = model_weights[1:] - model_weights[:-1] + def condition(): + self.assertEquals(len(model_weights), len(batches)) + return True + + # We want all batches to finish for this test. + self._eventually(condition, catch_assertions=True) + + w = array(model_weights) + diff = w[1:] - w[:-1] self.assertTrue(all(diff >= -0.1)) def test_prediction(self): @@ -1323,13 +1394,18 @@ def test_prediction(self): sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) - t = time() output_stream = slr.predictOnValues(input_stream) samples = [] output_stream.foreachRDD(lambda x: samples.append(x.collect())) self.ssc.start() - self._ssc_wait(t, 5, 0.01) + + def condition(): + self.assertEquals(len(samples), len(batches)) + return True + + # We want all batches to finish for this test. 
+ self._eventually(condition, catch_assertions=True) # Test that mean absolute error on each batch is less than 0.1 for batch in samples: @@ -1350,22 +1426,27 @@ def test_train_prediction(self): predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] - mean_absolute_errors = [] + errors = [] def func(rdd): true, predicted = zip(*rdd.collect()) - mean_absolute_errors.append(mean(abs(true) - abs(predicted))) + errors.append(mean(abs(true) - abs(predicted))) - model_weights = [] input_stream = self.ssc.queueStream(batches) output_stream = self.ssc.queueStream(predict_batches) - t = time() slr.trainOn(input_stream) output_stream = slr.predictOnValues(output_stream) output_stream.foreachRDD(func) self.ssc.start() - self._ssc_wait(t, 10, 0.01) - self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) + + def condition(): + if len(errors) == len(predict_batches): + self.assertGreater(errors[1] - errors[-1], 2) + if len(errors) >= 3 and errors[1] - errors[-1] > 2: + return True + return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) + + self._eventually(condition) class MLUtilsTests(MLlibTestCase): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 916de2d6fcdb..10a1e4b3eb0f 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -300,6 +300,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, :param: seed Random Seed :param: eps Used to scale the noise. If eps is set high, the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints """ weights = [float(weight) for weight in weights] diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fa8e0a0574a6..9ef60a7e2c84 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -700,7 +700,7 @@ def groupBy(self, f, numPartitions=None): return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) @ignore_unicode_prefix - def pipe(self, command, env={}, checkCode=False): + def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. @@ -709,6 +709,9 @@ def pipe(self, command, env={}, checkCode=False): :param checkCode: whether or not to check the return value of the shell command. 
""" + if env is None: + env = dict() + def func(iterator): pipe = Popen( shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 917de24f3536..0ef46c44644a 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -39,7 +39,7 @@ try: import pandas has_pandas = True -except ImportError: +except Exception: has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 47d5a6a43a84..025811f51929 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -566,8 +566,7 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - - if isinstance(on[0], basestring): + elif isinstance(on[0], basestring): jdf = self._jdf.join(other._jdf, self._jseq(on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" @@ -723,8 +722,6 @@ def __getitem__(self, item): [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): - if item not in self.columns: - raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): @@ -1205,7 +1202,9 @@ def freqItems(self, cols, support=None): @ignore_unicode_prefix @since(1.3) def withColumn(self, colName, col): - """Returns a new :class:`DataFrame` by adding a column. + """ + Returns a new :class:`DataFrame` by adding a column or replacing the + existing column that has the same name. :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. @@ -1213,7 +1212,8 @@ def withColumn(self, colName, col): >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ - return self.select('*', col.alias(colName)) + assert isinstance(col, Column), "col should be Column" + return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) @ignore_unicode_prefix @since(1.3) @@ -1226,10 +1226,7 @@ def withColumnRenamed(self, existing, new): >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ - cols = [Column(_to_java_column(c)).alias(new) - if c == existing else c - for c in self.columns] - return self.select(*cols) + return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) @since(1.4) @ignore_unicode_prefix diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 95f46044d324..4b74a501521a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -530,9 +530,10 @@ def lead(col, count=1, default=None): @since(1.4) def ntile(n): """ - Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in - a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will - get 2, the third row will get 3, and the fourth row will get 1... + Window function: returns the ntile group id (from 1 to `n` inclusive) + in an ordered window partition. For example, if `n` is 4, the first + quarter of the rows will get value 1, the second quarter will get 2, + the third quarter will get 3, and the last quarter will get 4. This is equivalent to the NTILE function in SQL. @@ -885,10 +886,10 @@ def crc32(col): returns the value as a bigint. 
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() - [Row(crc32=u'902fbdd2b1df0c4f70b4a5d23525e932')] + [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.md5(_to_java_column(col))) + return Column(sc._jvm.functions.crc32(_to_java_column(col))) @ignore_unicode_prefix diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bf6ac084bbbf..78247c8fa737 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -182,7 +182,7 @@ def orc(self, path): @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, - predicates=None, properties={}): + predicates=None, properties=None): """ Construct a :class:`DataFrame` representing the database table accessible via JDBC URL `url` named `table` and connection `properties`. @@ -208,6 +208,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar should be included. :return: a DataFrame """ + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) @@ -427,7 +429,7 @@ def orc(self, path, mode=None, partitionBy=None): self._jwrite.orc(path) @since(1.4) - def jdbc(self, url, table, mode=None, properties={}): + def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. note:: Don't create too many partitions in parallel on a large cluster;\ @@ -445,6 +447,8 @@ def jdbc(self, url, table, mode=None, properties={}): arbitrary string tag/value. Normally at least a "user" and "password" property should be included. 
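With properties now defaulting to None, callers still pass an ordinary dict per invocation. A sketch of reading and writing a table over JDBC using the signatures above (URL, table names and credentials are placeholders):

    props = {"user": "spark", "password": "secret"}          # placeholder credentials
    url = "jdbc:postgresql://localhost:5432/testdb"          # placeholder URL

    df = sqlContext.read.jdbc(url=url, table="people", properties=props)
    df.write.jdbc(url=url, table="people_copy", mode="overwrite", properties=props)
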
""" + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 38c83c427a74..aacfb34c7761 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -770,7 +770,7 @@ def test_access_column(self): self.assertTrue(isinstance(df['key'], Column)) self.assertTrue(isinstance(df[0], Column)) self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): @@ -794,7 +794,9 @@ def test_field_accessor(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) self.assertEqual("v", df.select(df.d["k"]).first()[0]) def test_infer_long_type(self): @@ -1033,6 +1035,10 @@ def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", lambda: df.select(sha2(df.a, 1024)).collect()) + def test_with_column_with_existing_name(self): + keys = self.df.withColumn("key", self.df.key).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + class HiveContextSQLTests(ReusedPySparkTestCase): @@ -1124,5 +1130,28 @@ def test_window_functions(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_without_partitionBy(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c083bf89905b..ed4e5b594bd6 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -467,9 +467,11 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ Construct a StructType by adding new elements to it to define the schema. The method accepts either: + a) A single parameter which is a StructField object. b) Between 2 and 4 parameters as (name, data_type, nullable (optional), - metadata(optional). The data_type parameter may be either a String or a DataType object + metadata(optional). The data_type parameter may be either a String or a + DataType object. 
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True),\ diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index c74745c726a0..eaf4d7e98620 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -64,7 +64,7 @@ def orderBy(*cols): Creates a :class:`WindowSpec` with the partitioning defined. """ sc = SparkContext._active_spark_context - jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 944fa414b0c0..0fee3b209682 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -30,7 +30,9 @@ class StatCounter(object): - def __init__(self, values=[]): + def __init__(self, values=None): + if values is None: + values = list() self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ac5ba69e8dbb..e3ba70e4e5e8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -86,6 +86,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ # TODO: support checkpoint in HDFS if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): @@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc): cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. + """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. 
This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -182,6 +231,7 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ @@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8dcb9645cdc6..698336cfce18 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -610,7 +610,10 @@ def __init__(self, prev, func): self.is_checkpointed = False self._jdstream_val = None - if (isinstance(prev, TransformedDStream) and + # Using type() to avoid folding the functions and compacting the DStreams which is not + # not strictly a object of TransformedDStream. + # Changed here is to avoid bug in KafkaTransformedDStream when calling offsetRanges(). 
+ if (type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func self.func = lambda t, rdd: func(t, prev_func(t, rdd)) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cbb573f226bb..c0cdc50d8d42 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -31,7 +31,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class FlumeUtils(object): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 33dd596335b4..8a814c64c042 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -29,13 +29,15 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -52,6 +54,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ + if kafkaParams is None: + kafkaParams = dict() kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -77,7 +81,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @staticmethod - def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. note:: Experimental @@ -105,6 +109,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, :param valueDecoder: A function used to decode value (default is utf8_decoder). :return: A DStream object """ + if fromOffsets is None: + fromOffsets = dict() if not isinstance(topics, list): raise TypeError("topics should be list") if not isinstance(kafkaParams, dict): @@ -129,7 +135,7 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod - def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + def createRDD(sc, kafkaParams, offsetRanges, leaders=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. 
note:: Experimental @@ -145,6 +151,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object """ + if leaders is None: + leaders = dict() if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") if not isinstance(offsetRanges, list): diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index bcfe2703fecf..34be5880e170 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -26,7 +26,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KinesisUtils(object): diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 000000000000..f06598971c54 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. 
+ Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... +________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cd544b2144e..214d5be43900 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,6 +24,7 @@ import tempfile import random import struct +import shutil from functools import reduce if sys.version_info[:2] <= (2, 6): @@ -40,6 +41,7 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream @@ -58,12 +60,21 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop(False) + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) def wait_for(self, result, n): start_time = time.time() @@ -441,6 +452,7 @@ def test_reduce_by_invalid_window(self): class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): inputs = [range(1, x) for x in range(101)] @@ -514,10 +526,85 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + 
self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + class CheckpointTests(unittest.TestCase): - def test_get_or_create(self): + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -532,11 +619,12 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc.start() def check_output(n): while not os.listdir(outputd): @@ -551,7 +639,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -567,13 +655,37 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) - ssc.stop(True, True) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files + self.ssc.stop(True, True) + shutil.rmtree(cpd) # delete checkpoint directory + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertTrue(self.setupCalled) + self.ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): @@ -738,7 +850,9 @@ def transformWithOffsetRanges(rdd): 
offsetRanges.append(o) return rdd - stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + # Test whether it is ok mixing KafkaTransformedDStream and TransformedDStream together, + # only the TransformedDstreams can be folded together. + stream.transform(transformWithOffsetRanges).map(lambda kv: kv[1]).count().pprint() self.ssc.start() self.wait_for(offsetRanges, 1) @@ -893,6 +1007,68 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -908,8 +1084,10 @@ def test_kinesis_stream_api(self): "awsAccessKey", "awsSecretKey") def test_kinesis_stream(self): - if os.environ.get('ENABLE_KINESIS_TESTS') != '1': - print("Skip test_kinesis_stream") + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) return import random @@ -950,6 +1128,7 @@ def get_output(_, rdd): traceback.print_exc() raise finally: + self.ssc.stop(False) kinesisTestUtils.deleteStream() kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) @@ -964,7 +1143,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) @@ -982,10 +1161,45 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. 
" % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = glob.glob( + os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " + "remove all but one") % mqtt_assembly_dir) + else: + return jars[0] + + +def search_mqtt_test_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") + jars = glob.glob( + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please " + "remove all but one") % mqtt_test_dir) else: return jars[0] @@ -997,11 +1211,7 @@ def search_kinesis_asl_assembly_jar(): os.path.join(kinesis_asl_assembly_dir, "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) if not jars: - raise Exception( - ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. 
" % - kinesis_asl_assembly_dir) + "You need to build Spark with " - "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " - "or 'build/mvn -Pkinesis-asl package' before running this test") + return None elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " "remove all but one") % kinesis_asl_assembly_dir) @@ -1009,11 +1219,48 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - unittest.main() + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, + CheckpointTests, KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled into a JAR. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. " + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 93df9002be37..42c2f8b75933 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -146,5 +146,5 @@ def process(): java_port = int(sys.stdin.readline()) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("a+", 65536) + sock_file = sock.makefile("rwb", 65536) main(sock_file, sock_file) diff --git a/python/run-tests.py b/python/run-tests.py index cc560779373b..fd56c7ab6e0e 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -158,7 +158,7 @@ def main(): else: log_level = logging.INFO logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") - LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + LOGGER.info("Running PySpark tests. 
Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 0374846d7167..501dff090313 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; @@ -59,7 +59,7 @@ public class UnsafeArrayData extends ArrayData { private int sizeInBytes; private int getElementOffset(int ordinal) { - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -157,7 +157,7 @@ public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return false; - return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, baseOffset + offset); } @Override @@ -165,7 +165,7 @@ public byte getByte(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, baseOffset + offset); } @Override @@ -173,7 +173,7 @@ public short getShort(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, baseOffset + offset); } @Override @@ -181,7 +181,7 @@ public int getInt(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, baseOffset + offset); } @Override @@ -189,7 +189,7 @@ public long getLong(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, baseOffset + offset); } @Override @@ -197,7 +197,7 @@ public float getFloat(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, baseOffset + offset); } @Override @@ -205,7 +205,7 @@ public double getDouble(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, baseOffset + offset); } @Override @@ -215,7 +215,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (offset < 0) return null; if (precision <= 
Decimal.MAX_LONG_DIGITS()) { - final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long value = Platform.getLong(baseObject, baseOffset + offset); return Decimal.apply(value, precision, scale); } else { final byte[] bytes = getBinary(ordinal); @@ -241,12 +241,7 @@ public byte[] getBinary(int ordinal) { if (offset < 0) return null; final int size = getElementSize(offset, ordinal); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size); + Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; } @@ -255,9 +250,8 @@ public CalendarInterval getInterval(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } @@ -307,27 +301,16 @@ public boolean equals(Object other) { } public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); final byte[] arrayDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - arrayDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + Platform.copyMemory( + baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java index b521b703389d..7b03185a30e3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class UnsafeReaders { public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); final UnsafeArrayData array = new UnsafeArrayData(); // Skip the first 4 bytes. array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); @@ -32,9 +32,9 @@ public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. 
- final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); + final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4); final int valueArraySize = numBytes - 8 - keyArraySize; final UnsafeArrayData keyArray = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e829acb6285f..6c020045c311 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,7 +27,7 @@ import java.util.Set; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -169,7 +169,7 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } @Override @@ -179,7 +179,7 @@ public void setNullAt(int i) { // To preserve row equality, zero out the value when setting the column to null. // Since this row does does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. 
- PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + Platform.putLong(baseObject, getFieldOffset(i), 0); } @Override @@ -191,14 +191,14 @@ public void update(int ordinal, Object value) { public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + Platform.putInt(baseObject, getFieldOffset(ordinal), value); } @Override public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + Platform.putLong(baseObject, getFieldOffset(ordinal), value); } @Override @@ -208,28 +208,28 @@ public void setDouble(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + Platform.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + Platform.putShort(baseObject, getFieldOffset(ordinal), value); } @Override public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + Platform.putByte(baseObject, getFieldOffset(ordinal), value); } @Override @@ -239,7 +239,7 @@ public void setFloat(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } /** @@ -263,24 +263,23 @@ public void setDecimal(int ordinal, Decimal value, int precision) { long cursor = getLong(ordinal) >>> 32; assert cursor > 0 : "invalid cursor " + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + Platform.putLong(baseObject, baseOffset + cursor, 0L); + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); if (value == null) { setNullAt(ordinal); // keep the offset for future update - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); + Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); - assert(mag.length <= 4); + byte[] bytes = integer.toByteArray(); + assert(bytes.length <= 16); // Write the bytes to the variable length portion. 
- PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - baseObject, baseOffset + cursor, mag.length * 4); - setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, baseObject, baseOffset + cursor, bytes.length); + setLong(ordinal, (cursor << 32) | ((long) bytes.length)); } } } @@ -336,47 +335,45 @@ public boolean isNullAt(int ordinal) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); + return Platform.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); + return Platform.getByte(baseObject, getFieldOffset(ordinal)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); + return Platform.getShort(baseObject, getFieldOffset(ordinal)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + return Platform.getInt(baseObject, getFieldOffset(ordinal)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + return Platform.getLong(baseObject, getFieldOffset(ordinal)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + return Platform.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + return Platform.getDouble(baseObject, getFieldOffset(ordinal)); } - private static byte[] EMPTY = new byte[0]; - @Override public Decimal getDecimal(int ordinal, int precision, int scale) { if (isNullAt(ordinal)) { @@ -385,20 +382,10 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { return Decimal.apply(getLong(ordinal), precision, scale); } else { - long offsetAndSize = getLong(ordinal); - long offset = offsetAndSize >>> 32; - int signum = ((int) (offsetAndSize & 0xfff) >> 8); - assert signum >=0 && signum <= 2 : "invalid signum " + signum; - int size = (int) (offsetAndSize & 0xff); - int[] mag = new int[size]; - PlatformDependent.copyMemory(baseObject, baseOffset + offset, - mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); - - // create a BigInteger using signum and mag - BigInteger v = new BigInteger(0, EMPTY); // create the initial object - PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); - PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); - return Decimal.apply(new BigDecimal(v, scale), precision, scale); + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); } } @@ -420,11 +407,11 @@ public byte[] getBinary(int ordinal) { final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final byte[] bytes = new byte[size]; - 
PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset + offset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, size ); return bytes; @@ -438,9 +425,8 @@ public CalendarInterval getInterval(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } } @@ -491,14 +477,14 @@ public MapData getMap(int ordinal) { public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(); final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset, rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); return rowCopy; } @@ -518,18 +504,13 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. this.baseObject = new byte[row.sizeInBytes]; } - PlatformDependent.copyMemory( - row.baseObject, - row.baseOffset, - this.baseObject, - this.baseOffset, - row.sizeInBytes - ); + Platform.copyMemory( + row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -544,19 +525,15 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; while (dataRemaining > 0) { int toTransfer = Math.min(writeBuffer.length, dataRemaining); - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + baseObject, rowReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); out.write(writeBuffer, 0, toTransfer); rowReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -584,13 +561,12 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. 
*/ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; - PlatformDependent.copyMemory(baseObject, baseOffset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return bytes; } } @@ -600,8 +576,7 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(java.lang.Long.toHexString( - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); + build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); @@ -619,12 +594,6 @@ public boolean anyNull() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 28e7ec0a0f12..2f43db68a750 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.CalendarInterval; @@ -58,29 +58,26 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(base, offset, 0L); - PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + Platform.putLong(base, offset, 0L); + Platform.putLong(base, offset + 8, 0L); if (input == null) { target.setNullAt(ordinal); // keep the offset and length for update int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; - PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + Platform.putLong(base, target.getBaseOffset() + fieldOffset, ((long) cursor) << 32); return SIZE; } final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); - int signum = integer.signum() + 1; - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); - assert(mag.length <= 4); + byte[] bytes = integer.toByteArray(); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - base, target.getBaseOffset() + cursor, mag.length * 4); + Platform.copyMemory( + bytes, Platform.BYTE_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, bytes.length); // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); + target.setLong(ordinal, (((long) cursor) << 32) | (long) bytes.length); return SIZE; } @@ -99,8 +96,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -125,8 +121,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -167,8 +162,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -191,8 +185,8 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + Platform.putLong(target.getBaseObject(), offset, input.months); + Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds); // Set the fixed length portion. target.setLong(ordinal, ((long) cursor) << 32); @@ -212,12 +206,11 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayDa final long offset = target.getBaseOffset() + cursor; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -247,14 +240,13 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); + Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes of key array to the variable length portion. 
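The Decimal changes in UnsafeRow and UnsafeRowWriters above drop the Unsafe-based access to BigInteger's private signum/mag fields: the unscaled value is now written with BigInteger.toByteArray (which already encodes the sign) and read back with new BigInteger(bytes) plus the original scale. The standalone Scala sketch below only illustrates that round trip; the object and method names (DecimalBytesSketch, toBytes, fromBytes) and the sample value are illustrative and are not part of this patch.

```scala
import java.math.{BigDecimal => JBigDecimal, BigInteger}

// Minimal sketch of the byte-array encoding used above for high-precision decimals:
// serialize the unscaled value, then rebuild it together with the original scale.
object DecimalBytesSketch {
  def toBytes(d: JBigDecimal): Array[Byte] =
    d.unscaledValue.toByteArray // sign is part of the two's-complement encoding

  def fromBytes(bytes: Array[Byte], scale: Int): JBigDecimal =
    new JBigDecimal(new BigInteger(bytes), scale)

  def main(args: Array[String]): Unit = {
    // 38 significant digits, the maximum precision the writers above deal with.
    val original = new JBigDecimal("-1234567890123456789012345678.9012345678")
    val restored = fromBytes(toBytes(original), original.scale)
    assert(restored.compareTo(original) == 0)
    println(restored)
  }
}
```

The `assert(bytes.length <= 16)` kept in UnsafeRow.setDecimal above is consistent with this encoding: any unscaled value of precision at most 38 is below 2^127 in magnitude, so its two's-complement form fits in 16 bytes.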
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java index 0e8e405d055d..cd83695fca03 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -36,17 +35,11 @@ public static void writeToMemory( // zero-out the padding bytes // if ((numBytes & 0x07) > 0) { -// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); // } // Write the UnsafeData to the target memory. - PlatformDependent.copyMemory( - inputObject, - inputOffset, - targetObject, - targetOffset, - numBytes - ); + Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes); } public static int getRoundedSize(int size) { @@ -68,16 +61,11 @@ public static int write(Object targetObject, long targetOffset, Decimal input) { assert(numBytes <= 16); // zero-out the bytes - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + Platform.putLong(targetObject, targetOffset, 0L); + Platform.putLong(targetObject, targetOffset + 8, 0L); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, - targetOffset, - numBytes); - + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return 16; } } @@ -111,8 +99,7 @@ public static int write(Object targetObject, long targetOffset, byte[] input) { final int numBytes = input.length; // Write the bytes to the variable length portion. - writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, targetOffset, numBytes); + writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return getRoundedSize(numBytes); } @@ -144,11 +131,9 @@ public static int getSize(UnsafeRow input) { } public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); - + Platform.putLong(targetObject, targetOffset, input.months); + Platform.putLong(targetObject, targetOffset + 8, input.microseconds); return 16; } } @@ -165,11 +150,11 @@ public static int write(Object targetObject, long targetOffset, UnsafeArrayData final int numBytes = input.getSizeInBytes(); // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // Write the bytes to the variable length portion. 
- writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset + 4, numBytes); + writeToMemory( + input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes); return getRoundedSize(numBytes + 4); } @@ -190,9 +175,9 @@ public static int write(Object targetObject, long targetOffset, UnsafeMapData in final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); + Platform.putInt(targetObject, targetOffset + 4, keysNumBytes); // Write the bytes of key array to the variable length portion. writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index a5ae2b973652..1d27182912c8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -157,7 +157,7 @@ public UnsafeRow next() { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: - PlatformDependent.throwException(e); + Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); }; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 40159aaf14d3..ec895af9c303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -364,31 +364,10 @@ trait Row extends Serializable { false } - /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on external row with internal row. - */ - protected def canEqual(other: Row) = { - // Note that `Row` is not only the interface of external row but also the parent - // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent - // call `equals` on external row with internal row. - // `InternalRow` overrides canEqual, and these two canEquals together makes sure that - // equality check between external Row and InternalRow will always fail. - // In the future, InternalRow should not extend Row. In that case, we can remove these - // canEqual methods. 
- !other.isInstanceOf[InternalRow] - } - override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) return false val other = o.asInstanceOf[Row] - if (!canEqual(other)) { - throw new UnsupportedOperationException( - "cannot check equality between external and internal rows") - } - if (other eq null) return false if (length != other.length) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala index aebcdeb9d070..d701559bf2d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -25,7 +25,9 @@ private[sql] case class TableIdentifier(table: String, database: Option[String] def toSeq: Seq[String] = database.toSeq :+ table - override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + override def toString: String = quotedString + + def quotedString: String = toSeq.map("`" + _ + "`").mkString(".") def unquotedString: String = toSeq.mkString(".") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa4..d0eb9c2c90bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -82,7 +82,9 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic) + PullOutNondeterministic), + Batch("Cleanup", fixedPoint, + CleanupAliases) ) /** @@ -146,8 +148,6 @@ class Analyzer( child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne - case g: GetStructField => Alias(g, g.field.name)() - case g: GetArrayStructFields => Alias(g, g.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u case other => Alias(other, s"_c$i")() @@ -384,9 +384,7 @@ class Analyzer( case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) - } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -412,11 +410,6 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def trimUnresolvedAlias(ne: NamedExpression) = ne match { - case UnresolvedAlias(child) => child - case other => other - } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. 
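The new Batch("Cleanup", fixedPoint, CleanupAliases) registered in the Analyzer above (the rule itself appears further down in this file's diff) relies on the observation that an Alias only matters at the top level of a Project, Aggregate, or Window expression, where it names an output column; aliases nested deeper inside an expression tree only get in the way of comparing and matching plans. The toy Scala model below is not Catalyst code, just a minimal sketch of the "trim non-top-level aliases" idea over hypothetical case classes (Attr, Alias, Add):

```scala
// Toy expression tree, standing in for Catalyst expressions.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Alias(child: Expr, name: String) extends Expr
case class Add(left: Expr, right: Expr) extends Expr

object AliasCleanupSketch {
  // Drop every alias in a subtree: nested aliases carry no information.
  private def trimAliases(e: Expr): Expr = e match {
    case Alias(child, _) => trimAliases(child)
    case Add(l, r)       => Add(trimAliases(l), trimAliases(r))
    case other           => other
  }

  // Keep an alias only at the top level, where it names the output column.
  def trimNonTopLevelAliases(e: Expr): Expr = e match {
    case Alias(child, name) => Alias(trimAliases(child), name)
    case other              => trimAliases(other)
  }

  def main(args: Array[String]): Unit = {
    val projectItem = Alias(Add(Alias(Attr("a"), "x"), Attr("b")), "total")
    // Prints Alias(Add(Attr(a),Attr(b)),total): the nested alias "x" is removed,
    // while the top-level alias "total" survives because it names the column.
    println(trimNonTopLevelAliases(projectItem))
  }
}
```

The real rule additionally stops at CreateStruct / CreateStructUnsafe, whose top-level child Aliases determine StructField names and therefore must be preserved.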
@@ -426,7 +419,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + plan.resolve(nameParts, resolver).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -931,6 +924,7 @@ class Analyzer( */ object PullOutNondeterministic extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -968,3 +962,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] { case Subquery(_, child) => child } } + +/** + * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level + * expression in Project(project list) or Aggregate(aggregate expressions) or + * Window(window expressions). + */ +object CleanupAliases extends Rule[LogicalPlan] { + private def trimAliases(e: Expression): Expression = { + var stop = false + e.transformDown { + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } + + def trimNonTopLevelAliases(e: Expression): Expression = e match { + case a: Alias => + Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + case other => trimAliases(other) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Project(projectList, child) => + val cleanedProjectList = + projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Project(cleanedProjectList, child) + + case Aggregate(grouping, aggs, child) => + val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) + + case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val cleanedWindowExprs = + windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) + Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + + case other => + var stop = false + other transformExpressionsDown { + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5766e6a2dd51..503c4f4b20f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} @@ -55,12 +56,15 @@ trait Catalog { def refreshTable(tableIdent: TableIdentifier): Unit + // TODO: Refactor it in the work of SPARK-10104 def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit + // TODO: Refactor it in the work of SPARK-10104 def unregisterTable(tableIdentifier: Seq[String]): Unit def unregisterAllTables(): Unit + // TODO: Refactor it in the work of SPARK-10104 protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { if (conf.caseSensitiveAnalysis) { tableIdentifier @@ -69,6 +73,7 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDbTableName(tableIdent: Seq[String]): String = { val size = tableIdent.size if (size <= 2) { @@ -78,9 +83,22 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { (tableIdent.lift(tableIdent.size - 2), tableIdent.last) } + + /** + * It is not allowed to specify a database name for tables stored in [[SimpleCatalog]]. + * We use this method to check it. + */ + protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = { + if (tableIdentifier.length > 1) { + throw new AnalysisException("Specifying a database name or other qualifiers is not allowed " + + "for temporary tables. If the table name has dots (.) in it, please quote the " + + "table name with backticks (`).") + } + } } class SimpleCatalog(val conf: CatalystConf) extends Catalog { @@ -89,11 +107,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.remove(getDbTableName(tableIdent)) } @@ -103,6 +123,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def tableExists(tableIdentifier: Seq[String]): Boolean = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.containsKey(getDbTableName(tableIdent)) } @@ -110,6 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def lookupRelation( tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) val table = tables.get(tableFullName) @@ -149,7 +171,13 @@ trait OverrideCatalog extends Catalog { abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - overrides.get(getDBTable(tableIdent)) match { + // A temporary table only has a single part in the tableIdentifier. 
+ val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } + overriddenTable match { case Some(_) => true case None => super.tableExists(tableIdentifier) } @@ -159,7 +187,12 @@ trait OverrideCatalog extends Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) - val overriddenTable = overrides.get(getDBTable(tableIdent)) + // A temporary table only has a single part in the tableIdentifier. + val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -171,20 +204,8 @@ trait OverrideCatalog extends Catalog { } abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val dbName = if (conf.caseSensitiveAnalysis) { - databaseName - } else { - if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None - } - - val temporaryTables = overrides.filter { - // If a temporary table does not have an associated database, we should return its name. - case ((None, _), _) => true - // If a temporary table does have an associated database, we should return it if the database - // matches the given database name. - case ((db: Some[String], _), _) if db == dbName => true - case _ => false - }.map { + // We always return all temporary tables. + val temporaryTables = overrides.map { case ((_, tableName), _) => (tableName, true) }.toSeq @@ -194,13 +215,19 @@ trait OverrideCatalog extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) overrides.put(getDBTable(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.remove(getDBTable(tableIdent)) + // A temporary table only has a single part in the tableIdentifier. + // If tableIdentifier has more than one part, it is not a temporary table + // and we do not need to do anything here. + if (tableIdentifier.length == 1) { + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.remove(getDBTable(tableIdent)) + } } override def unregisterAllTables(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 970f3c8282c8..2cb067f4aac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -164,7 +164,7 @@ object HiveTypeCoercion { // Leave the same if the dataTypes match. 
case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}") newType } } @@ -371,8 +371,8 @@ object HiveTypeCoercion { DecimalType.bounded(range + scale, scale) } - private def changePrecision(e: Expression, dataType: DataType): Expression = { - ChangeDecimalPrecision(Cast(e, dataType)) + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { @@ -383,36 +383,42 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Add(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Subtract(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2) - Multiply(changePrecision(e1, dt), changePrecision(e2, dt)) + val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - Divide(changePrecision(e1, dt), changePrecision(e2, dt)) + val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), + max(6, s1 + p2 + 1)) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)), + CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. 
val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType) + CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => @@ -599,7 +605,7 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -616,7 +622,7 @@ object HiveTypeCoercion { case c: CaseKeyWhen if c.childrenResolved && !c.resolved => val maybeCommonType = - findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) + findWiderCommonType((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => @@ -633,6 +639,7 @@ object HiveTypeCoercion { */ object IfCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 946c5a9c04f1..2db954257be3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType) case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != Decimal.ZERO) + buildCast[Decimal](_, !_.isZero) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType) case TimestampType => // Note that we lose precision here. 
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) - case DecimalType() => + case dt: DecimalType => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) - case LongType => - b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case t: IntegralType => + b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) + case x: FractionalType => b => try { - changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { case _: NumberFormatException => null } @@ -447,7 +447,7 @@ case class Cast(child: Expression, dataType: DataType) case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) - case decimal: DecimalType => castToDecimalCode(from, decimal) + case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) @@ -528,17 +528,18 @@ case class Cast(child: Expression, dataType: DataType) } """ - private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + private[this] def castToDecimalCode( + from: DataType, + target: DecimalType, + ctx: CodeGenContext): CastFunction = { + val tmp = ctx.freshName("tmpDecimal") from match { case StringType => (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - new scala.math.BigDecimal( - new java.math.BigDecimal($c.toString()))); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -546,13 +547,8 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = null; - if ($c) { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); - } else { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); - } - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); + ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive @@ -561,33 +557,29 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. 
(c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c.clone(); + ${changePrecision(tmp, target, evPrim, evNull)} """ - case LongType => + case x: IntegralType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set($c); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply((long) $c); + ${changePrecision(tmp, target, evPrim, evNull)} """ - case x: NumericType => + case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index b76757c93523..d3560df0792e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -37,20 +37,20 @@ class JoinedRow extends InternalRow { } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { + def withLeft(newLeft: InternalRow): JoinedRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { + def withRight(newRight: InternalRow): JoinedRow = { row2 = newRight this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 2cf8312ea59a..5e8298aaaa9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A var result: Any = null override def update(input: InternalRow): Unit = { + // We ignore null values. 
if (result == null) { result = expr.eval(input) } @@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag var result: Any = null override def update(input: InternalRow): Unit = { - result = input + val value = expr.eval(input) + // We ignore null values. + if (value != null) { + result = value + } } override def eval(input: InternalRow): Any = { - if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null + result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8..bf96248feaef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -27,7 +28,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ @@ -265,6 +266,45 @@ class CodeGenContext { def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Splits the generated code of expressions into multiple functions, because function has + * 64kb code size limit in JVM + * + * @param row the variable name of row that is used by expressions + */ + def splitExpressions(row: String, expressions: Seq[String]): String = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (code <- expressions) { + // We can't know how many byte code will be generated, so use the number of bytes as limit + if (blockBuilder.length > 64 * 1000) { + blocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(code) + } + blocks.append(blockBuilder.toString()) + + if (blocks.length == 1) { + // inline execution if only one block + blocks.head + } else { + val apply = freshName("apply") + val functions = blocks.zipWithIndex.map { case (body, i) => + val name = s"${apply}_$i" + val code = s""" + |private void $name(InternalRow $row) { + | $body + |} + """.stripMargin + addNewFunction(name, code) + name + } + + functions.map(name => s"$name($row);").mkString("\n") + } + } } /** @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString + }.mkString("\n") } protected def initMutableStates(ctx: 
CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString + ctx.mutableStates.map(_._3).mkString("\n") } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString + ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** @@ -328,8 +368,10 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( - classOf[PlatformDependent].getName, + classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ac58423cd884..b4d4df8934bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val projectionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ } } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline execution if codesize limit was not broken - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } + val allProjections = ctx.splitExpressions("i", projectionCodes) val code = s""" public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + return new SpecificMutableProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificProjection($exprType[] expr) { + public SpecificMutableProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} @@ -123,12 +95,9 @@ object 
GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allProjections return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ef08ddf041af..7ad352d7ce3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types._ @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") + // These expressions could be splitted into multiple functions + ctx.addMutableState("Object[]", values, s"this.$values = null;") + val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $values[$i] = ${converter.primitive}; } """ - }.mkString("\n") - + } + val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - final Object[] $values = new Object[${schema.length}]; - $fieldWriters + this.$values = new Object[${schema.length}]; + $allFields final InternalRow $output = new $rowClass($values); """ @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline it if we have only one block - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } - + val allExpressions = ctx.splitExpressions("i", expressionCodes) val code = s""" public Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); @@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private $exprType[] 
expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificSafeProjection($exprType[] expr) { expressions = expr; @@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allExpressions return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d8912df694a1..b570fe86db1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - private val PlatformDependent = classOf[PlatformDependent].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case NullType => true @@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + $DecimalWriter.getSize(${ev.primitive})" + s"$DecimalWriter.getSize(${ev.primitive})" case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})" case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})" case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" + s"${ev.isNull} ? 0 : 16" case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})" case _: ArrayType => - s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})" case _: MapType => - s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 
0 : $MapWriter.getSize(${ev.primitive})" case _ => "" } @@ -125,64 +122,68 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro */ private def createCodeForStruct( ctx: CodeGenContext, + row: String, inputs: Seq[GeneratedExpressionCode], inputTypes: Seq[DataType]): GeneratedExpressionCode = { + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();") val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val numBytes = ctx.freshName("numBytes") + ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") + ctx.addMutableState("int", cursor, s"this.$cursor = 0;") + val tmp = ctx.freshName("tmpBuffer") - val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => - createConvertCode(ctx, input, dt) - } - - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - if (dt.isInstanceOf[DecimalType]) { - // Can't call setNullAt() for DecimalType + val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => + val ev = createConvertCode(ctx, input, dt) + val growBuffer = if (!UnsafeRow.isFixedLength(dt)) { + val numBytes = ctx.freshName("numBytes") s""" + int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. 
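The generated code above reuses a single byte buffer across rows and grows it geometrically when a row needs more space. A framework-free sketch of that strategy (the `ReusableBuffer` name is illustrative):

```
// One byte[] is reused across rows; when a row needs more space the buffer roughly doubles
// and the old contents are copied over (Platform.copyMemory in the generated Java).
final class ReusableBuffer(initialSize: Int = 64) {
  private var buf = new Array[Byte](initialSize)

  def ensureCapacity(numBytes: Int): Array[Byte] = {
    if (buf.length < numBytes) {
      val grown = new Array[Byte](numBytes * 2)        // geometric growth keeps reallocations rare
      System.arraycopy(buf, 0, grown, 0, buf.length)
      buf = grown
    }
    buf
  }
}
```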
+ byte[] $tmp = new byte[$numBytes * 2]; + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmp; + } + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); + """ + } else { + "" + } + val update = dt match { + case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => + // Can't call setNullAt() for DecimalType + s""" if (${ev.isNull}) { - $cursor += $DecimalWriter.write($output, $i, $cursor, null); + $cursor += $DecimalWriter.write($output, $i, $cursor, null); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ - } else { - s""" + case _ => + s""" if (${ev.isNull}) { $output.setNullAt($i); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ } - }.mkString("\n") + s""" + ${ev.code} + $growBuffer + $update + """ + } val code = s""" - ${convertedFields.map(_.code).mkString("\n")} - - final int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, - $numBytes); - - int $cursor = $fixedSize; - - $fieldWriters + $cursor = $fixedSize; + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) } @@ -223,7 +224,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // go through the input array to calculate how many bytes we need. val calculateNumBytes = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" @@ -236,6 +237,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => val writer = getWriter(elementType) val elementSize = s"$writer.getSize($elements[$index])" + // TODO(davies): avoid the copy val unsafeType = elementType match { case _: StructType => "UnsafeRow" case _: ArrayType => "UnsafeArrayData" @@ -248,8 +250,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => "" } + val newElements = if (elementType == BinaryType) { + s"new byte[$numElements][]" + } else { + s"new $unsafeType[$numElements]" + } s""" - final $unsafeType[] $elements = new $unsafeType[$numElements]; + final $unsafeType[] $elements = $newElements; for (int $index = 0; $index < $numElements; $index++) { ${convertedElement.code} if (!${convertedElement.isNull}) { @@ -261,21 +268,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeElement = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? 
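The `ctx.splitExpressions(row, convertedFields)` call above relies on the chunking helper added to `CodeGenContext` earlier in this patch. A simplified sketch of the idea, returning the helper-method definitions together with their call sites rather than registering them through `addNewFunction` (names and the size threshold are illustrative):

```
// Accumulate generated statements into source blocks below a size threshold (a rough proxy
// for the JVM's 64KB-per-method bytecode limit), then either inline the single block or emit
// one private helper method per block plus the call sequence that replaces the inline code.
def splitIntoMethods(
    row: String,
    statements: Seq[String],
    limit: Int = 64 * 1000): (String, String) = {
  val blocks = scala.collection.mutable.ArrayBuffer.empty[String]
  val current = new StringBuilder
  for (stmt <- statements) {
    if (current.length > limit) { blocks += current.toString(); current.clear() }
    current.append(stmt).append("\n")
  }
  blocks += current.toString()

  if (blocks.size == 1) {
    ("", blocks.head)                                   // small enough to stay inline
  } else {
    val defs = blocks.zipWithIndex.map { case (body, i) =>
      s"private void apply_$i(InternalRow $row) {\n$body}"
    }.mkString("\n")
    val calls = blocks.indices.map(i => s"apply_$i($row);").mkString("\n")
    (defs, calls)
  }
}
```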
val elementSize = elementType.defaultSize s""" - $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - $PlatformDependent.UNSAFE.putLong( + Platform.putLong( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -284,7 +291,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -318,23 +325,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. - $PlatformDependent.UNSAFE.putInt( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - -$cursor); + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - $PlatformDependent.UNSAFE.putInt( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - $cursor); - + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } } $output.pointTo( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } @@ -400,7 +400,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldIsNull = s"$tmp.isNullAt($i)" GeneratedExpressionCode("", fieldIsNull, getFieldCode) } - val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; @@ -427,7 +427,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) - createCodeForStruct(ctx, exprEvals, exprTypes) + createCodeForStruct(ctx, "i", exprEvals, exprTypes) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 30b51dd83fa9..da91ff29537b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -52,9 +52,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val offset = 
PlatformDependent.BYTE_ARRAY_OFFSET - val getLong = "PlatformDependent.UNSAFE.getLong" - val putLong = "PlatformDependent.UNSAFE.putLong" + val offset = Platform.BYTE_ARRAY_OFFSET + val getLong = "Platform.getLong" + val putLong = "Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -96,7 +96,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U var cursor = offset + outputBitsetWords * 8 val copyFixedLengthRow1 = s""" |// Copy fixed length data for row1 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${bitset1Words * 8}, | buf, $cursor, | ${schema1.size * 8}); @@ -106,7 +106,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy fixed length portion from row 2 ----------------------- // val copyFixedLengthRow2 = s""" |// Copy fixed length data for row2 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${bitset2Words * 8}, | buf, $cursor, | ${schema2.size * 8}); @@ -118,7 +118,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, | numBytesVariableRow1); @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, | numBytesVariableRow2); @@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } - }.mkString + }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4a071e663e0d..1c546719730b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { @@ -208,7 +206,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) @@ -246,7 +246,9 @@ case class CreateNamedStructUnsafe(children: 
Seq[Expression]) extends Expression override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala similarity index 72% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index adb33e4c8d4a..b7be12f7aa74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -66,10 +66,44 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un * An expression used to wrap the children when promote the precision of DecimalType to avoid * promote multiple times. */ -case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression { +case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" - override def prettyName: String = "change_decimal_precision" + override def prettyName: String = "promote_precision" +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not, returns null if not. 
+ */ +case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Decimal].clone() + if (d.changePrecision(dataType.precision, dataType.scale)) { + d + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + val tmp = ctx.freshName("tmp") + s""" + | Decimal $tmp = $eval.clone(); + | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { + | ${ev.primitive} = $tmp; + | } else { + | ${ev.isNull} = true; + | } + """.stripMargin + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d474853355e5..c0845e1a0102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.Map - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala new file mode 100644 index 000000000000..6dff28a7cde4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.regex.{MatchResult, Pattern} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait StringRegexExpression extends ImplicitCastInputTypes { + self: BinaryExpression => + + def escape(v: String): String + def matches(regex: Pattern, str: String): Boolean + + override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + // try cache the pattern for Literal + private lazy val cache: Pattern = right match { + case x @ Literal(value: String, StringType) => compile(value) + case _ => null + } + + protected def compile(str: String): Pattern = if (str == null) { + null + } else { + // Let it raise exception if couldn't compile the regex string + Pattern.compile(escape(str)) + } + + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) + if(regex == null) { + null + } else { + matches(regex, input1.asInstanceOf[UTF8String].toString) + } + } +} + + +/** + * Simple RegEx pattern matching function + */ +case class Like(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression with CodegenFallback { + + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) + + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
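`Like` above precompiles the pattern when the right-hand side is a foldable literal, after converting the SQL LIKE syntax to a Java regex via `StringUtils.escapeLikeRegex`. A simplified sketch of that conversion, ignoring backslash escapes (the `likeToRegex` name is illustrative):

```
import java.util.regex.Pattern

// '%' matches any sequence, '_' matches a single character, everything else is literal.
def likeToRegex(pattern: String): String =
  pattern.map {
    case '%' => ".*"
    case '_' => "."
    case c   => Pattern.quote(c.toString)
  }.mkString

// Pattern.compile(likeToRegex("a%b_c")).matcher("axxbyc").matches() == true
```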
+ val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } +} + + +case class RLike(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression with CodegenFallback { + + override def escape(v: String): String = v + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } +} + + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName + nullSafeCodeGen(ctx, ev, (str, pattern) => + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") + } + + override def prettyName: String = "split" +} + + +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
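The THREAD-SAFE note above refers to the mutable `lastRegex`/`pattern` fields that let the expression recompile the regex only when its value changes between rows. A simplified per-instance sketch of that caching, using `Matcher.replaceAll` instead of the reused `StringBuffer` in the real code (the class name is illustrative):

```
import java.util.regex.Pattern

// Remember the last regex string and its compiled Pattern so repeated rows with the same
// (usually constant) regex compile only once. The mutable fields are exactly why this kind
// of expression is documented as not thread-safe.
final class CachedRegexReplacer {
  private var lastRegex: String = _
  private var pattern: Pattern = _

  def replaceAll(subject: String, regex: String, replacement: String): String = {
    if (regex != lastRegex) {          // recompile only when the regex value actually changes
      lastRegex = regex
      pattern = Pattern.compile(regex)
    }
    pattern.matcher(subject).replaceAll(replacement)
  }
}
```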
+ */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. + @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + UTF8String.fromString(result.toString) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState("UTF8String", + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep.clone(); + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; + """ + }) + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
+ */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 + } + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala similarity index 73% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 76666bd6b3d2..48d02bb53450 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -21,13 +21,9 @@ import java.text.DecimalFormat import java.util.Arrays import java.util.{Map => JMap, HashMap} import java.util.Locale -import java.util.regex.{MatchResult, Pattern} - -import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -76,7 +72,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * Returns null if the separator is null. Otherwise, concat_ws skips all null values. 
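A reference sketch of the `concat_ws` null handling restated in the comment above, with an illustrative helper name and plain Scala collections standing in for Spark's string and array types:

```
// A null separator yields null; null inputs (including null array elements) are skipped
// rather than making the whole result null.
def concatWs(sep: String, inputs: Seq[Any]): String = {
  if (sep == null) return null
  val parts = inputs.flatMap {
    case null        => Nil
    case s: String   => Seq(s)
    case arr: Seq[_] => arr.collect { case s: String => s }   // array arguments are flattened
    case other       => Seq(other.toString)
  }
  parts.mkString(sep)
}

// concatWs("-", Seq("a", null, Seq("b", null, "c"))) == "a-b-c"
```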
*/ case class ConcatWs(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, s"$prettyName requires at least one argument.") @@ -118,149 +114,48 @@ case class ConcatWs(children: Seq[Expression]) boolean ${ev.isNull} = ${ev.primitive} == null; """ } else { - // Contains a mix of strings and arrays. Fall back to interpreted mode for now. - super.genCode(ctx, ev) - } - } -} - + val array = ctx.freshName("array") + val varargNum = ctx.freshName("varargNum") + val idxInVararg = ctx.freshName("idxInVararg") -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => - - def escape(v: String): String - def matches(regex: Pattern, str: String): Boolean - - override def dataType: DataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - // try cache the pattern for Literal - private lazy val cache: Pattern = right match { - case x @ Literal(value: String, StringType) => compile(value) - case _ => null - } - - protected def compile(str: String): Pattern = if (str == null) { - null - } else { - // Let it raise exception if couldn't compile the regex string - Pattern.compile(escape(str)) - } - - protected def pattern(str: String) = if (cache == null) compile(str) else cache - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val regex = pattern(input2.asInstanceOf[UTF8String].toString()) - if(regex == null) { - null - } else { - matches(regex, input1.asInstanceOf[UTF8String].toString()) - } - } -} - -/** - * Simple RegEx pattern matching function - */ -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = StringUtils.escapeLikeRegex(v) - - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
- val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); - """ - }) - } - } -} - - -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = v - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" + val evals = children.map(_.gen(ctx)) + val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => + child.dataType match { + case StringType => + ("", // we count all the StringType arguments num at once below. + s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.primitive};") + case _: ArrayType => + val size = ctx.freshName("n") + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.primitive}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.primitive}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.primitive, StringType, "j")}; + } + } + """) + } + }.unzip - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
- val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile(rightStr); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); - """ - }) + evals.map(_.code).mkString("\n") + + s""" + int $varargNum = ${children.count(_.dataType == StringType) - 1}; + int $idxInVararg = 0; + ${varargCount.mkString("\n")} + UTF8String[] $array = new UTF8String[$varargNum]; + ${varargBuild.mkString("\n")} + UTF8String ${ev.primitive} = UTF8String.concatWs(${evals.head.primitive}, $array); + boolean ${ev.isNull} = ${ev.primitive} == null; + """ } } } - trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -305,7 +200,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ImplicitCastInputTypes { +trait StringPredicate extends Predicate with ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -322,7 +217,7 @@ trait StringComparison extends ImplicitCastInputTypes { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -333,7 +228,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -344,7 +239,7 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -550,13 +445,14 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * in given string after position pos. 
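A small sketch of the `locate` behaviour the new codegen below implements: positions are 1-based, a miss returns 0, a null substring or string yields null, and a null start position yields 0 rather than null (the helper name is illustrative):

```
def locate(substr: String, str: String, start: Option[Int]): Option[Int] = {
  if (substr == null || str == null) return None            // null in, null out
  start match {
    case None      => Some(0)                                // null start position => 0
    case Some(pos) => Some(str.indexOf(substr, pos) + 1)     // indexOf is 0-based; locate is 1-based
  }
}

// locate("ar", "bar", Some(0)) == Some(2); locate("zz", "bar", Some(0)) == Some(0)
```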
*/ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil + override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -582,6 +478,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val substrGen = substr.gen(ctx) + val strGen = str.gen(ctx) + val startGen = start.gen(ctx) + s""" + int ${ev.primitive} = 0; + boolean ${ev.isNull} = false; + ${startGen.code} + if (!${startGen.isNull}) { + ${substrGen.code} + if (!${substrGen.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.indexOf(${substrGen.primitive}, + ${startGen.primitive}) + 1; + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "locate" } @@ -769,32 +690,6 @@ case class StringSpace(child: Expression) override def prettyName: String = "space" } -/** - * Splits str around pat (pattern is a regular expression). - */ -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = pattern - override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) - new GenericArrayData(strings.asInstanceOf[Array[Any]]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => - // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") - } - - override def prettyName: String = "split" -} - object Substring { def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = { if (pos > bytes.length) { @@ -1013,7 +908,7 @@ case class Decode(bin: Expression, charset: Expression) try { ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); } """) } @@ -1043,168 +938,11 @@ case class Encode(value: Expression, charset: Expression) try { ${ev.primitive} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); }""") } } -/** - * Replace all substrings of str that match regexp with rep. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
- */ -case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - // last replacement string, we don't want to convert a UTF8String => java.langString every time. - @transient private var lastReplacement: String = _ - @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private val result: StringBuffer = new StringBuffer - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) - - UTF8String.fromString(result.toString) - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = subject :: regexp :: rep :: Nil - override def prettyName: String = "regexp_replace" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") - - val classNamePattern = classOf[Pattern].getCanonicalName - val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - - ctx.addMutableState("UTF8String", - termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, - termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", - termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") - - nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!$rep.equals(${termLastReplacementInUTF8})) { - // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); - ${ev.isNull} = false; - """ - }) - } -} - -/** - * Extract a specific(idx) group identified by a Java regex. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
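`RegExpExtract`, removed just below, applies the same pattern-caching trick and returns the empty string rather than null when nothing matches. The core `java.util.regex` group extraction, again as a standalone sketch rather than the Spark implementation:

    import java.util.regex.Pattern

    def extract(input: String, regex: String, groupIdx: Int): String = {
      val m = Pattern.compile(regex).matcher(input)
      if (m.find()) m.toMatchResult.group(groupIdx) else ""   // "" on no match, as above
    }

    // extract("100-200", "(\\d+)-(\\d+)", 2)  == "200"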
- */ -case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - def this(s: Expression, r: Expression) = this(s, r, Literal(1)) - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } else { - UTF8String.EMPTY_UTF8 - } - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = subject :: regexp :: idx :: Nil - override def prettyName: String = "regexp_extract" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - val classNamePattern = classOf[Pattern].getCanonicalName - - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - - nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - java.util.regex.Matcher m = - ${termPattern}.matcher($subject.toString()); - if (m.find()) { - java.util.regex.MatchResult mr = m.toMatchResult(); - ${ev.primitive} = UTF8String.fromString(mr.group($idx)); - ${ev.isNull} = false; - } else { - ${ev.primitive} = UTF8String.EMPTY_UTF8; - ${ev.isNull} = false; - }""" - }) - } -} - /** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. 
If D is 0, the result has no decimal point or @@ -1306,8 +1044,8 @@ case class FormatNumber(x: Expression, d: Expression) $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; $numberFormat.applyPattern($dFormat.toPattern()); - ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } + ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else { ${ev.primitive} = null; ${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ab5ac2c61e3..854463dd11c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] { * * - Inserting Projections beneath the following operators: * - Aggregate + * - Generate * - Project <- Join * - LeftSemiJoin */ @@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + // Eliminate attributes that are not needed to calculate the Generate. + case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => + g.copy(child = Project(g.references.toSeq, g.child)) + + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => + p.copy(child = g.copy(join = false)) + + case p @ Project(projectList, g: Generate) if g.join => + val neededChildOutput = p.references -- g.generatorOutput ++ g.references + if (neededChildOutput == g.child.outputSet) { + p + } else { + Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + } + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) if (a.outputSet -- p.references).nonEmpty => Project( @@ -260,8 +276,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] { val substitutedProjection = projectList1.map(_.transform { case a: Attribute => aliasMap.getOrElse(a, a) }).asInstanceOf[Seq[NamedExpression]] - - Project(substitutedProjection, child) + // collapse 2 projects may introduce unnecessary Aliases, trim them here. 
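The trimming on the next line delegates to `CleanupAliases.trimNonTopLevelAliases`. As a rough mental model of what "trim non-top-level aliases" means, here is a toy expression ADT (not Catalyst's real classes) behaving the way the SPARK-9634 test further down expects:

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Lit(value: Int) extends Expr
    case class Add(left: Expr, right: Expr) extends Expr
    case class Alias(child: Expr, name: String) extends Expr

    // Drop every alias inside the tree...
    def stripAliases(e: Expr): Expr = e match {
      case Alias(child, _) => stripAliases(child)
      case Add(l, r)       => Add(stripAliases(l), stripAliases(r))
      case leaf            => leaf
    }
    // ...but keep the outermost one, since it carries the output column name.
    def trimNonTopLevelAliases(e: Expr): Expr = e match {
      case Alias(child, name) => Alias(stripAliases(child), name)
      case other              => stripAliases(other)
    }

    // ((a + 1) AS "a+1") + 2 AS "col"  ==>  (a + 1 + 2) AS "col"
    trimNonTopLevelAliases(Alias(Add(Alias(Add(Attr("a"), Lit(1)), "a+1"), Lit(2)), "col"))
    // == Alias(Add(Add(Attr("a"), Lit(1)), Lit(2)), "col")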
+ val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + Project(cleanedProjection, child) } } } @@ -353,7 +372,7 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => e } - case e: StringComparison => e.children match { + case e: StringPredicate => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c290e6acb361..9bb466ac2d29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and wrap it with UnresolvedAlias which will be removed later. + // and aliased it with the last part of the name. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as - // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - Some(UnresolvedAlias(fieldExprs)) + Some(Alias(fieldExprs, nestedFields.last)()) // No matches. 
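The resolution change above now aliases the extracted value with the last name part ("c" for "a.b.c") instead of leaving an `UnresolvedAlias` behind. The nesting order produced by the `foldLeft` is easiest to see on plain strings (purely illustrative, no Spark classes):

    val nestedFields = Seq("b", "c")
    val fieldExpr = nestedFields.foldLeft("a") { (expr, field) => s"ExtractValue($expr, $field)" }
    // fieldExpr == "ExtractValue(ExtractValue(a, b), c)", then aliased as "c"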
case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7c404722d811..73b8261260ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,7 @@ case class Window( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - (projectList ++ windowExpressions).map(_.toAttribute) + projectList ++ windowExpressions.map(_.toAttribute) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a89a90b735a..5ac3f1f5b0ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -216,26 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.toSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } + } /** @@ -257,15 +254,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 8b824511a79d..f80d2a93241d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,22 +17,25 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide -import 
scala.collection.mutable - object RuleExecutor { - protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + protected val timeMap = AtomicLongMap.create[String]() /** Resets statistics about time spent running specific rules */ def resetTime(): Unit = timeMap.clear() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val maxSize = timeMap.keys.map(_.toString.length).max - timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + val map = timeMap.asMap().asScala + val maxSize = map.keys.map(_.toString.length).max + map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" }.mkString("\n") } @@ -79,7 +82,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { logTrace( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2..5770f59b5307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b20..7bcd623b3f33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. 
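The new `existsRecursively` hook (defaulted here and overridden for the container types elsewhere in this patch) lets callers ask whether any type in a schema tree matches a predicate. A sketch of the intended kind of use; note the method is `private[spark]`, so this only compiles from inside Spark's own packages:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("prices", ArrayType(DecimalType(10, 2)))))

    // true: a DecimalType occurs nested inside the array element type
    val containsDecimal = schema.existsRecursively(_.isInstanceOf[DecimalType])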
+ */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 624c3f3d7fa9..c988f1d1b972 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal(MATH_CONTEXT) + decimalVal } else { - BigDecimal(longVal, _scale)(MATH_CONTEXT) + BigDecimal(longVal, _scale) } } @@ -267,7 +267,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale) } else { - Decimal(toBigDecimal + that.toBigDecimal, precision, scale) + Decimal(toBigDecimal + that.toBigDecimal) } } @@ -275,18 +275,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale) } else { - Decimal(toBigDecimal - that.toBigDecimal, precision, scale) + Decimal(toBigDecimal - that.toBigDecimal) } } // HiveTypeCoercion will take care of the precision, scale of result - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827c..00461e529ca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9cbc207538d4..d8968ef80639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException 
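The `Decimal` changes above route `*`, `/` and `%` through `java.math.BigDecimal` with the class's `MATH_CONTEXT`. The context matters most for division, where an unbounded-precision quotient may not be representable at all; a standalone illustration, using `DECIMAL128` merely as a stand-in for whatever precision `MATH_CONTEXT` carries:

    import java.math.{BigDecimal => JBigDecimal, MathContext}

    val one   = new JBigDecimal(1)
    val three = new JBigDecimal(3)

    // one.divide(three) would throw ArithmeticException: non-terminating decimal expansion
    val bounded = one.divide(three, MathContext.DECIMAL128)   // 0.3333... to 34 digits
    val product = one.multiply(three, MathContext.DECIMAL128) // rounding applied the same way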
import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 000000000000..5b802ccc637d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 63b475b6366c..7065adce04bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types._ @@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { +class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ def errorTest( @@ -149,6 +145,11 @@ class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { UnresolvedTestPlan(), "unresolved" :: Nil) + errorTest( + "SPARK-9955: correct error message for aggregate", + // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. 
+ testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), + "cannot resolve 'bad_column'" :: Nil) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c944bc69e25b..1e0cc81dae97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") { + val a = testRelation.output.head + var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col")) + var expected = testRelation.select((a + 1 + 2).as("col")) + checkAnalysis(plan, expected) + + plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col")) + expected = testRelation.groupBy(a)((min(a) + 1).as("col")) + checkAnalysis(plan, expected) + + // CreateStruct is a special case that we should not trim Alias for it. + plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index cbdf453f600a..6f33ab733b61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -285,6 +285,17 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Cast(Literal(100L), DecimalType(22, 2)), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + ) } test("type coercion simplification for equal to") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a1f15e4f0f25..72285c6a2419 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -52,6 +52,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) + + 
DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) + } } test("- (UnaryMinus)") { @@ -71,6 +75,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(UnaryMinus, tpe) + } } test("- (Minus)") { @@ -85,6 +93,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort - negativeShort).toShort) checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) + } } test("* (Multiply)") { @@ -99,6 +111,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort * negativeShort).toShort) checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) + } } test("/ (Divide) basic") { @@ -111,6 +127,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) + } } test("/ (Divide) for integral type") { @@ -144,6 +164,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0) checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) + + // TODO: the following lines would fail the test due to inconsistency result of interpret + // and codegen for remainder between giant values, seems like a numeric stability issue + // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + // checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) + // } } test("Abs") { @@ -161,6 +187,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(negativeIntLit), - negativeInt) checkEvaluation(Abs(positiveLongLit), positiveLong) checkEvaluation(Abs(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe) + } } test("MaxOf basic") { @@ -175,6 +205,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort) checkEvaluation(MaxOf(positiveIntLit, negativeIntLit), positiveInt) checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MaxOf, tpe, tpe) + } } test("MaxOf for atomic type") { @@ -196,6 +230,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort) checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt) checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) + } } test("MinOf for atomic type") { @@ -222,4 +260,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt) checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong) } + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 4fc1c0615359..3a310c0e9a7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -45,6 +45,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) + } } test("BitwiseAnd") { @@ -68,6 +72,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort & negativeShort).toShort) checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) + } } test("BitwiseOr") { @@ -91,6 +99,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort | negativeShort).toShort) checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) + } } test("BitwiseXor") { @@ -110,10 +122,13 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) checkEvaluation(BitwiseXor(Literal(1), nullLit), null) checkEvaluation(BitwiseXor(nullLit, nullLit), null) - checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), (positiveShort ^ negativeShort).toShort) checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index d26bcdb2902a..0df673bb9fa0 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -66,6 +66,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toLong, TimestampType) testIf(_.toString, StringType) + + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(If, BooleanType, dt, dt) + } } test("case when") { @@ -176,6 +180,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) + } } test("function greatest") { @@ -218,6 +226,9 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) - } + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f9b73f1a75e7..610d39e8493c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -60,6 +60,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) } test("Year") { @@ -79,6 +80,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) } test("Quarter") { @@ -98,6 +100,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) } test("Month") { @@ -117,6 +120,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Month, DateType) } test("Day / DayOfMonth") { @@ -135,6 +139,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } + checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth, DateType) } test("Seconds") { @@ -149,6 +154,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } + checkConsistencyBetweenInterpretedAndCodegen(Second, TimestampType) } test("WeekOfYear") { @@ -157,6 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } test("DateFormat") { @@ -184,6 +191,7 @@ class DateExpressionsSuite 
extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Hour, TimestampType) } test("Minute") { @@ -200,6 +208,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.MINUTE)) } } + checkConsistencyBetweenInterpretedAndCodegen(Minute, TimestampType) } test("date_add") { @@ -218,6 +227,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) checkEvaluation( DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType) } test("date_sub") { @@ -236,6 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) checkEvaluation( DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) + checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType) } test("time_add") { @@ -254,6 +265,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeAdd, TimestampType, CalendarIntervalType) } test("time_sub") { @@ -277,6 +289,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeSub, TimestampType, CalendarIntervalType) } test("add_months") { @@ -296,6 +309,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) checkEvaluation( AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) + checkConsistencyBetweenInterpretedAndCodegen(AddMonths, DateType, IntegerType) } test("months_between") { @@ -320,6 +334,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthsBetween(t, tnull), null) checkEvaluation(MonthsBetween(tnull, t), null) checkEvaluation(MonthsBetween(tnull, tnull), null) + checkConsistencyBetweenInterpretedAndCodegen(MonthsBetween, TimestampType, TimestampType) } test("last_day") { @@ -337,6 +352,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(LastDay, DateType) } test("next_day") { @@ -370,6 +386,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ToDate(Literal(Date.valueOf("2015-07-22"))), DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) checkEvaluation(ToDate(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType) } test("function trunc") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala new file mode 100644 index 000000000000..511f0307901d --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} + + +class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("UnscaledValue") { + val d1 = Decimal("10.1") + checkEvaluation(UnscaledValue(Literal(d1)), 101L) + val d2 = Decimal(101, 3, 1) + checkEvaluation(UnscaledValue(Literal(d2)), 101L) + checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null) + } + + test("MakeDecimal") { + checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) + checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + } + + test("PromotePrecision") { + val d1 = Decimal("10.1") + checkEvaluation(PromotePrecision(Literal(d1)), d1) + val d2 = Decimal(101, 3, 1) + checkEvaluation(PromotePrecision(Literal(d2)), d2) + checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null) + } + + test("CheckOverflow") { + val d1 = Decimal("10.1") + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) + + val d2 = Decimal(101, 3, 1) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) + + checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a41185b4d875..465f7d08aa14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import 
org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -211,4 +214,111 @@ trait ExpressionEvalHelper { plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against unary expressions by feeding them arbitrary literals of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Expression => Expression, + dataType: DataType): Unit = { + forAll (LiteralGenerator.randomGen(dataType)) { (l: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2) + ) { (l1: Literal, l2: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against ternary expressions by feeding them arbitrary literals of `dataType1`, + * `dataType2` and `dataType3`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType, + dataType3: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2), + LiteralGenerator.randomGen(dataType3) + ) { (l1: Literal, l2: Literal, l3: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2, l3)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against expressions take Seq[Expression] as input by feeding them + * arbitrary length Seq of arbitrary literal of `dataType`. 
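The comparison helper defined a little further down (`compareResults`) special-cases byte arrays and NaN. A quick reminder of why plain `==` is not enough for those results:

    Double.NaN == Double.NaN                       // false: IEEE-754 NaN never equals itself
    Array[Byte](1, 2) == Array[Byte](1, 2)         // false: arrays compare by reference
    java.util.Arrays.equals(Array[Byte](1, 2), Array[Byte](1, 2))   // true: element-wise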
+ */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Seq[Expression] => Expression, + dataType: DataType, + minNumElements: Int = 0): Unit = { + forAll (Gen.listOf(LiteralGenerator.randomGen(dataType))) { (literals: Seq[Literal]) => + whenever(literals.size >= minNumElements) { + cmpInterpretWithCodegen(EmptyRow, c(literals)) + } + } + } + + private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { + val interpret = try { + evaluate(expr, inputRow) + } catch { + case e: Exception => fail(s"Exception evaluating $expr", e) + } + + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil)(), + expr) + val codegen = plan(inputRow).get(0, expr.dataType) + + if (!compareResults(interpret, codegen)) { + fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen") + } + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte] and Spread[Double]. + */ + private[this] def compareResults(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) + case (result: Double, expected: Double) if result.isNaN && expected.isNaN => + true + case (result: Float, expected: Float) if result.isNaN && expected.isNaN => + true + case _ => result == expected + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala new file mode 100644 index 000000000000..ee6d25157fc0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.Matchers +import org.scalatest.prop.GeneratorDrivenPropertyChecks + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Property is a high-level specification of behavior that should hold for a range of data points. + * + * For example, while we are evaluating a deterministic expression for some input, we should always + * hold the property that the result never changes, regardless of how we get the result, + * via interpreted or codegen. + * + * In ScalaTest, properties are specified as functions and the data points used to check properties + * can be supplied by either tables or generators. 
+ * + * Generator-driven property checks are performed via integration with ScalaCheck. + * + * @example {{{ + * def toTest(i: Int): Boolean = if (i % 2 == 0) true else false + * + * import org.scalacheck.Gen + * + * test ("true if param is even") { + * val evenInts = for (n <- Gen.choose(-1000, 1000)) yield 2 * n + * forAll(evenInts) { (i: Int) => + * assert (toTest(i) === true) + * } + * } + * }}} + * + */ +object LiteralGenerator { + + lazy val byteLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) + + lazy val shortLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbShort.arbitrary } yield Literal.create(s, ShortType) + + lazy val integerLiteralGen: Gen[Literal] = + for { i <- Arbitrary.arbInt.arbitrary } yield Literal.create(i, IntegerType) + + lazy val longLiteralGen: Gen[Literal] = + for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType) + + lazy val floatLiteralGen: Gen[Literal] = + for { + f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2, + Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity) + } yield Literal.create(f, FloatType) + + lazy val doubleLiteralGen: Gen[Literal] = + for { + f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2, + Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) + } yield Literal.create(f, DoubleType) + + // TODO: decimal type + + lazy val stringLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) + + lazy val binaryLiteralGen: Gen[Literal] = + for { ab <- Gen.listOf[Byte](Arbitrary.arbByte.arbitrary) } + yield Literal.create(ab.toArray, BinaryType) + + lazy val booleanLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) + + lazy val dateLiteralGen: Gen[Literal] = + for { d <- Arbitrary.arbInt.arbitrary } yield Literal.create(new Date(d), DateType) + + lazy val timestampLiteralGen: Gen[Literal] = + for { t <- Arbitrary.arbLong.arbitrary } yield Literal.create(new Timestamp(t), TimestampType) + + lazy val calendarIntervalLiterGen: Gen[Literal] = + for { m <- Arbitrary.arbInt.arbitrary; s <- Arbitrary.arbLong.arbitrary} + yield Literal.create(new CalendarInterval(m, s), CalendarIntervalType) + + + // Sometimes, it would be quite expensive when unlimited value is used, + // for example, the `times` arguments for StringRepeat would hang the test 'forever' + // if it's tested against Int.MaxValue by ScalaCheck, therefore, use values from a limited + // range is more reasonable + lazy val limitedIntegerLiteralGen: Gen[Literal] = + for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType) + + def randomGen(dt: DataType): Gen[Literal] = { + dt match { + case ByteType => byteLiteralGen + case ShortType => shortLiteralGen + case IntegerType => integerLiteralGen + case LongType => longLiteralGen + case DoubleType => doubleLiteralGen + case FloatType => floatLiteralGen + case DateType => dateLiteralGen + case TimestampType => timestampLiteralGen + case BooleanType => booleanLiteralGen + case StringType => stringLiteralGen + case BinaryType => binaryLiteralGen + case CalendarIntervalType => calendarIntervalLiterGen + case dt => throw new IllegalArgumentException(s"not supported type $dt") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 033792eee6c0..90c59f240b54 100644 
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ - class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ @@ -184,60 +183,74 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("sin") { testUnary(Sin, math.sin) + checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) } test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) } test("sinh") { testUnary(Sinh, math.sinh) + checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) } test("cos") { testUnary(Cos, math.cos) + checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) } test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("cosh") { testUnary(Cosh, math.cosh) + checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) } test("tan") { testUnary(Tan, math.tan) + checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) } test("atan") { testUnary(Atan, math.atan) + checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) } test("tanh") { testUnary(Tanh, math.tanh) + checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) } test("toDegrees") { testUnary(ToDegrees, math.toDegrees) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("toRadians") { testUnary(ToRadians, math.toRadians) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) } test("cbrt") { testUnary(Cbrt, math.cbrt) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } test("ceil") { testUnary(Ceil, math.ceil) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) } test("floor") { testUnary(Floor, math.floor) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) } test("factorial") { @@ -247,37 +260,45 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) } test("rint") { testUnary(Rint, math.rint) + checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) } test("exp") { testUnary(Exp, math.exp) + checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) } test("expm1") { testUnary(Expm1, math.expm1) + checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) } test("signum") { testUnary[Double, Double](Signum, math.signum) + checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) } test("log") { testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) } test("log10") { testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) + 
checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) } test("log1p") { testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) } test("bin") { @@ -298,12 +319,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) + + checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (1 to 20).map(_ * 0.1)) testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) } test("sqrt") { @@ -313,11 +337,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) } test("shift left") { @@ -338,6 +364,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) } test("shift right") { @@ -358,6 +387,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) } test("shift right unsigned") { @@ -386,6 +418,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { negativeLong >>> positiveInt) checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), negativeLong >>> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) } test("hex") { @@ -400,6 +435,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on + Seq(LongType, BinaryType, StringType).foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) + } } test("unhex") { @@ -413,16 +451,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars 
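// Editor's aside (illustrative only, not part of this patch): the checkConsistencyBetweenInterpretedAndCodegen
// calls added throughout these suites are differential tests; the same expression is evaluated through two
// independent paths (interpreted and code-generated) on random literals, and the results must agree. The
// stand-alone sketch below mimics that idea with two independent implementations of signum; the object name
// and the NaN-free input range are assumptions chosen for illustration.
import org.scalacheck.{Gen, Prop, Test}

object AgreementSketch {
  def referenceSignum(x: Double): Double = math.signum(x)
  def alternateSignum(x: Double): Double = if (x > 0) 1.0 else if (x < 0) -1.0 else x

  // chooseNum never produces NaN unless asked to, which keeps the equality comparison simple.
  val inputs: Gen[Double] = Gen.chooseNum(-1e6, 1e6)

  val bothPathsAgree: Prop = Prop.forAll(inputs) { x =>
    referenceSignum(x) == alternateSignum(x)
  }

  def main(args: Array[String]): Unit =
    println(Test.check(Test.Parameters.default, bothPathsAgree).passed)
}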
checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) checkEvaluation(Unhex(Literal("三重的")), null) - // scalastyle:on + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) } test("hypot") { testBinary(Hypot, math.hypot) + checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) } test("atan2") { testBinary(Atan2, math.atan2) + checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) } test("binary log") { @@ -454,6 +494,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } test("round") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index b524d0af14a6..75d17417e5a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,6 +29,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { @@ -37,6 +38,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { @@ -55,6 +57,6 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 7beef71845e4..54c04faddb47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -73,6 +73,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } + checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) + } + + test("AND, OR, EqualTo, EqualNullSafe consistency check") { + checkConsistencyBetweenInterpretedAndCodegen(And, BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Or, BooleanType, BooleanType) + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt) + } } booleanLogicTest("AND", And, @@ -180,6 +190,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, 
Double.NaN, true).map(Literal(_)) + test("BinaryComparison consistency check") { + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(LessThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual, dt, dt) + } + } + test("BinaryComparison: lessThan") { for (i <- 0 until smallValues.length) { checkEvaluation(LessThan(smallValues(i), largeValues(i)), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 426dc272471a..99e3b13ce8c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -673,7 +673,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) } - test("number format") { + test("format_number / FormatNumber") { checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index aff1bee99faa..796d60032e1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * A test suite for the bitset portion of the row concatenation. @@ -96,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala new file mode 100644 index 000000000000..098944a9f4fc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A test suite for generated projections + */ +class GeneratedProjectionSuite extends SparkFunSuite { + + test("generated projections on wider table") { + val N = 1000 + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs)() + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + + test("generated unsafe projection with array of binary") { + val row = InternalRow( + Array[Byte](1, 2), + new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4)))) + val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType] + + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow: UnsafeRow = unsafeProj(row) + 
assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) + assert(unsafeRow.getArray(1).isNullAt(1)) + assert(unsafeRow.getArray(1).getBinary(1) === null) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) + + val safeProj = FromUnsafeProjection(fields) + val row2 = safeProj(unsafeRow) + assert(row2 === row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala new file mode 100644 index 000000000000..dbebcb86809d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.Explode +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.types.StringType + +class ColumnPruningSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Column pruning", FixedPoint(100), + ColumnPruning) :: Nil + } + + test("Column pruning for Generate when Generate.join = false") { + val input = LocalRelation('a.int, 'b.array(StringType)) + + val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + Generate(Explode('b), false, false, None, 's.string :: Nil, + Project('b.attr :: Nil, input)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning for Generate when Generate.join = true") { + val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + + val query = + Project(Seq('a, 's), + Generate(Explode('c), true, false, None, 's.string :: Nil, + input)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + Project(Seq('a, 's), + Generate(Explode('c), true, false, None, 's.string :: Nil, + Project(Seq('a, 'c), + input))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Turn Generate.join to false if possible") { + val input = LocalRelation('b.array(StringType)) + + val query = + Project(('s + 1).as("s+1") :: Nil, + Generate(Explode('b), true, false, None, 's.string :: Nil, + input)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + 
Project(('s + 1).as("s+1") :: Nil, + Generate(Explode('b), false, false, None, 's.string :: Nil, + input)).analyze + + comparePlans(optimized, correctAnswer) + } + + // todo: add more tests for column pruning +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d7..706ecd29d135 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 417df006ab7c..ed2c641d63e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -46,6 +46,25 @@ object DataTypeTestUtils { */ val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + // TODO: remove this once we find out how to handle decimal properly in property check + val numericTypeWithoutDecimal: Set[DataType] = integralType ++ Set(DoubleType, FloatType) + + /** + * Instances of all [[NumericType]]s and [[CalendarIntervalType]] + */ + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + + /** + * All the types that support ordering + */ + val ordered: Set[DataType] = + numericTypeWithoutDecimal + BooleanType + TimestampType + DateType + StringType + BinaryType + + /** + * All the types that we can use in a property check + */ + val propertyCheckSupported: Set[DataType] = ordered + /** * Instances of all [[AtomicType]]s. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 00218f213054..09511ff35f78 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,6 +19,8 @@ import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; @@ -27,7 +29,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -138,7 +140,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo unsafeGroupingKeyRow.getBaseOffset(), unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length ); if (!putSucceeded) { @@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } + @VisibleForTesting + public int getNumDataPages() { + return map.getNumDataPages(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 69d6784713a2..7db6b7ff50f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -225,7 +225,7 @@ public boolean next() throws IOException { int recordLen = underlying.getRecordLength(); // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cc32d4b72748..ca50000b4756 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,3 @@ -org.apache.spark.sql.jdbc.DefaultSource -org.apache.spark.sql.json.DefaultSource 
-org.apache.spark.sql.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.jdbc.DefaultSource +org.apache.spark.sql.execution.datasources.json.DefaultSource +org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 27bd08484734..807bc8c30c12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as("colB")) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = Alias(expr, alias)() + def as(alias: String): Column = expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as('colB)) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = Alias(expr, alias.name)() + def as(alias: Symbol): Column = expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } /** * Gives the column an alias with metadata. 
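// Editor's sketch of the behaviour change documented above; the data, column names and the "comment"
// metadata key are made-up examples, and the object name is an assumption. It shows that a plain
// .as(alias) now carries existing column metadata over to the alias, while aliasing with an explicitly
// empty Metadata opts back out.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

object AliasMetadataSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("alias-metadata"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(("alice", 30), ("bob", 25))).toDF("name", "age")
    val meta = new MetadataBuilder().putString("comment", "age in years").build()

    // Attach metadata through the existing as(alias, metadata) overload.
    val tagged = df.select(df("name"), df("age").as("age", meta))

    // With this change, a plain alias keeps the metadata attached ...
    val renamed = tagged.select(tagged("age").as("age2"))
    println(renamed.schema("age2").metadata)   // contains "comment" -> "age in years"

    // ... while an explicitly empty Metadata drops it again.
    val cleared = tagged.select(tagged("age").as("age3", Metadata.empty))
    println(cleared.schema("age3").metadata)   // {}

    sc.stop()
  }
}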
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 570b8b2d5928..d6688b24ae7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,12 +34,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -634,6 +634,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ @@ -641,6 +642,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ @@ -1131,7 +1133,8 @@ class DataFrame private[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Returns a new [[DataFrame]] by adding a column. + * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has + * the same name. * @group dfops * @since 1.3.0 */ @@ -1149,6 +1152,23 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + } + select(colNames : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } + /** * Returns a new [[DataFrame]] with a column renamed. * This is a no-op if schema doesn't contain existingName. 
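// Editor's sketch of the two behaviours documented above; the case classes, names and data are made up.
// It shows that apply/col can reference a nested field such as "address.city", and that withColumn
// replaces an existing column of the same name instead of appending a duplicate.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

case class Address(city: String, zip: String)
case class Person(name: String, age: Int, address: Address)

object DataFrameColumnSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("df-columns"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(
      Person("alice", 30, Address("berlin", "10115")),
      Person("bob", 25, Address("paris", "75001")))).toDF()

    // Nested column reference via the column name.
    df.select(df("address.city")).show()

    // Replacing the existing "age" column keeps the schema width unchanged.
    val bumped = df.withColumn("age", df("age") + 1)
    bumped.select("name", "age").show()

    sc.stop()
  }
}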
@@ -1560,8 +1580,10 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: HadoopFsRelation) => - fsBasedRelation.paths.toSeq + case LogicalRelation(fsBasedRelation: FileRelation) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles }.flatten files.toSet.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 85f33c5e9952..6dc7bfe33349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, Partition} @@ -197,7 +197,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { table: String, parts: Array[Partition], connectionProperties: Properties): DataFrame = { - val relation = JDBCRelation(url, table, parts, connectionProperties)(sqlContext) + val props = new Properties() + extraOptions.foreach { case (key, value) => + props.put(key, value) + } + // connectionProperties should override settings in extraOptions + props.putAll(connectionProperties) + val relation = JDBCRelation(url, table, parts, props)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2a4992db09bc..ce8744b53175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.HadoopFsRelation @@ -218,7 +218,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { case _ => val cmd = CreateTableUsingAsSelect( - tableIdent.unquotedString, + tableIdent, source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), @@ -244,7 +244,13 @@ final class DataFrameWriter private[sql](df: DataFrame) { * should be included. 
*/ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { - val conn = JdbcUtils.createConnection(url, connectionProperties) + val props = new Properties() + extraOptions.foreach { case (key, value) => + props.put(key, value) + } + // connectionProperties should override settings in extraOptions + props.putAll(connectionProperties) + val conn = JdbcUtils.createConnection(url, props) try { var tableExists = JdbcUtils.tableExists(conn, table) @@ -264,7 +270,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Create the table if the table didn't exist. if (!tableExists) { - val schema = JDBCWriteDetails.schemaString(df, url) + val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" conn.prepareStatement(sql).executeUpdate() } @@ -272,7 +278,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { conn.close() } - JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + JdbcUtils.saveTable(df, url, table, props) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 832572571cab..126c9c6f839c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.ui.{SQLListener, SQLTab} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -334,97 +332,21 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self /** * Converts $"col name" into an [[Column]]. * @since 1.3.0 */ + // This must live here to preserve binary compatibility with Spark < 1.5. implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { new ColumnName(sc.s(args: _*)) } } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. 
In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } } + // scalastyle:on /** * :: Experimental :: @@ -662,9 +584,10 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = None, source, temporary = false, @@ -672,7 +595,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -707,9 +630,10 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = Some(schema), source, temporary = false, @@ -717,7 +641,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -802,7 +726,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + table(new SqlParser().parseTableIdentifier(tableName)) + } + + private def table(tableIdent: TableIdentifier): DataFrame = { DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) } @@ -873,7 +800,7 @@ class SQLContext(@transient val sparkContext: SparkContext) HashAggregation :: Aggregation :: LeftSemiJoin :: - HashJoin :: + EquiJoinSelection :: InMemoryScans :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 000000000000..47b6f80bed48 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + */ +private[sql] abstract class SQLImplicits { + protected def _sqlContext: SQLContext + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of case classes or tuples. + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. 
+ * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1f270560d7bc..fc4d0938c533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -56,6 +56,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { /** * Register a user-defined aggregate function (UDAF). + * * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. * @return the registered UDAF. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d553bb6169ec..66d429bc0619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -120,9 +120,7 @@ private[sql] case class InMemoryRelation( new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => - val columnType = ColumnType(attribute.dataType) - val initialBufferSize = columnType.defaultSize * batchSize - ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression) + ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) }.toArray var rowCount = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index c91d960a0932..ca910a99db08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -270,20 +270,13 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - private val dictionary = { - // TODO Can we clean up this mess? Maybe move this to `DataType`? 
- implicit val classTag = { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) - } - - Array.fill(buffer.getInt()) { - columnType.extract(buffer) - } + private val dictionary: Array[Any] = { + val elementNum = buffer.getInt() + Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } override def next(row: MutableRow, ordinal: Int): Unit = { - columnType.setField(row, ordinal, dictionary(buffer.getShort())) + columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) } override def hasNext: Boolean = buffer.hasRemaining diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index e8c6a0f8f801..f3b6a3a5f4a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -45,6 +46,10 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil @@ -121,12 +126,15 @@ case class Aggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 var i = 0 while (i < buffer.length) { buffer(i).update(currentRow) @@ -142,6 +150,7 @@ case class Aggregate( i += 1 } + numOutputRows += 1 Iterator(resultProjection(aggregateResults)) } } else { @@ -152,6 +161,7 @@ case class Aggregate( var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 val currentGroup = groupingProjection(currentRow) var currentBuffer = hashTable.get(currentGroup) if (currentBuffer == null) { @@ -180,6 +190,7 @@ case class Aggregate( val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue + numOutputRows += 1 var i = 0 while (i < currentBuffer.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b89e634761eb..029f2264a6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -46,7 +46,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Returns true iff we can support the data type, and we are not doing range partitioning. 
*/ private lazy val tungstenMode: Boolean = { - GenerateUnsafeProjection.canSupport(child.schema) && + unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && !newPartitioning.isInstanceOf[RangePartitioning] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cae7ca5cbdc8..abb60cf12e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -99,8 +99,6 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], extraInformation: String) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - protected override def doExecute(): RDD[InternalRow] = rdd override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala new file mode 100644 index 000000000000..7a2a9eed5807 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * An interface for relations that are backed by files. When a class implements this interface, + * the list of paths that it returns will be returned to a user who calls `inputPaths` on any + * DataFrame that queries this relation. + */ +private[sql] trait FileRelation { + /** Returns the list of files that will be read when scanning this relation. 
*/ + def inputFiles: Array[String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 858dd85fd1fa..34e926e4582b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -30,8 +30,6 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) protected override def doExecute(): RDD[InternalRow] = rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala new file mode 100644 index 000000000000..7462dbc4eba3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + */ +private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ + def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ + def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. 
+ */ + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var hasNextWasCalled: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + // Idempotency: + if (!hasNextWasCalled) { + _hasNext = rowIter.advanceNext() + hasNextWasCalled = true + } + _hasNext + } + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + hasNextWasCalled = false + rowIter.getRow + } +} + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 97f1323e9783..cee58218a885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1915496d1620..72f5450510a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -80,29 +80,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } - /** - * Whether track the number of rows output by this SparkPlan - */ - protected[sql] def trackNumOfRowsEnabled: Boolean = false - - private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = - if (trackNumOfRowsEnabled) { - Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - } - else { - Map.empty - } - /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics - - /** - * Return a IntSQLMetric according to the name. 
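With the global numRows tracking removed from SparkPlan, row counting becomes an opt-in, per-operator concern. A minimal sketch of the new pattern, modeled on the Project operator later in this patch (the operator itself is hypothetical and assumes it lives in an org.apache.spark.sql.execution source file):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.metric.SQLMetrics

case class CountingPassThrough(child: SparkPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output

  // Declare only the metrics this operator cares about; the default is now Map.empty.
  override private[sql] lazy val metrics = Map(
    "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numRows = longMetric("numRows")
    child.execute().mapPartitions { iter =>
      iter.map { row => numRows += 1; row }
    }
  }
}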
- */ - private[sql] def intMetric(name: String): IntSQLMetric = - metrics(name).asInstanceOf[IntSQLMetric] + private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** * Return a LongSQLMetric according to the name. @@ -156,15 +137,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() - if (trackNumOfRowsEnabled) { - val numRows = longMetric("numRows") - doExecute().map { row => - numRows += 1 - row - } - } else { - doExecute() - } + doExecute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4b9b5acea4d..4df53687a073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching join keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable and + * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join + * will be used. + * - Hash: will be chosen if neither of the above optimizations apply to this join. 
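A user-level illustration of the broadcast hint mentioned in the first bullet (the DataFrames and join column are hypothetical): wrapping one side in functions.broadcast() makes this strategy plan a broadcast hash join regardless of the size estimate, while the automatic, size-based path is governed by spark.sql.autoBroadcastJoinThreshold.

import org.apache.spark.sql.functions.broadcast

// Hint that `smallDf` should be broadcast; `largeDf`, `smallDf` and "id" are made up.
val joined = largeDf.join(broadcast(smallDf), largeDf("id") === smallDf("id"))

// The size-based variant is controlled by this threshold (in bytes; -1 disables it).
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)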
*/ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], @@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => val mergeJoin = @@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( @@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } @@ -376,22 +395,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => + case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( - tableName, userSpecifiedSchema, provider, opts)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts)) :: Nil case c: CreateTableUsing if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. 
Use a HiveContext instead.") case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query) + case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) if partitionsCols.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) => + case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => val cmd = CreateTempTableUsingAsSelect( - tableName, provider, Array.empty[String], mode, opts, query) + tableIdent, provider, Array.empty[String], mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 6c7e5cacc99e..5c18558f9bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as @@ -108,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = { new Iterator[(Int, UnsafeRow)] { private[this] var rowSize: Int = dIn.readInt() + if (rowSize == EOF) dIn.close() override def hasNext: Boolean = rowSize != EOF @@ -116,9 +117,10 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream + dIn.close() val _rowTuple = rowTuple // Null these out so that the byte array can be garbage collected once the entire // iterator has been consumed @@ -150,7 +152,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ad428ad663f3..f4c14a9b3556 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -24,6 +24,7 @@ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class SortBasedAggregate( @@ -38,6 +39,10 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = false @@ -63,6 +68,8 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. @@ -84,10 +91,13 @@ case class SortBasedAggregate( newProjection _, child.output, iter, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. + numOutputRows += 1 Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter @@ -98,6 +108,10 @@ case class SortBasedAggregate( override def simpleString: String = { val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions - s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 67ebafde25ad..73d50e07cf0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.KVIterator /** @@ -37,7 +38,9 @@ class SortBasedAggregationIterator( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends AggregationIterator( groupingKeyAttributes, valueAttributes, @@ -103,6 +106,7 @@ class 
SortBasedAggregationIterator( // Get the grouping key. val groupingKey = inputKVIterator.getKey val currentRow = inputKVIterator.getValue + numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { @@ -137,7 +141,7 @@ class SortBasedAggregationIterator( val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) // Initialize buffer values for the next group. initializeBuffer(sortBasedAggregationBuffer) - + numOutputRows += 1 outputRow } else { // no more result @@ -151,7 +155,7 @@ class SortBasedAggregationIterator( nextGroupingKey = inputKVIterator.getKey().copy() firstRowInNextGroup = inputKVIterator.getValue().copy() - + numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. @@ -181,7 +185,9 @@ object SortBasedAggregationIterator { newProjection: (Seq[Expression], Seq[Attribute]) => Projection, inputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { AggregationIterator.unsafeKVIterator( groupingExprs, @@ -202,7 +208,9 @@ object SortBasedAggregationIterator { initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1694794a53d9..ba379d358d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -35,6 +37,10 @@ case class TungstenAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -61,32 +67,58 @@ case class TungstenAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. 
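The same metric objects are also threaded into helper iterators as plain constructor parameters, as SortBasedAggregationIterator now does with numInputRows and numOutputRows. A hedged, stand-alone sketch of that pattern (the wrapper class is illustrative and assumes access to the private[sql] metric type):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.LongSQLMetric

// Counts every row that flows through it by bumping the metric it was handed.
class InputCountingIterator(delegate: Iterator[InternalRow], numInputRows: LongSQLMetric)
  extends Iterator[InternalRow] {

  override def hasNext: Boolean = delegate.hasNext

  override def next(): InternalRow = {
    numInputRows += 1   // accumulator-style update, same as in the aggregation iterators
    delegate.next()
  }
}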
- Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] - } else { - val aggregationIterator = - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - completeAggregateExpressions, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - iter, - testFallbackStartsAt) - - if (!hasInput && groupingExpressions.isEmpty) { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") + + /** + * Set up the underlying unsafe data structures used before computing the parent partition. + * This makes sure our iterator is not starved by other operators in the same task. + */ + def preparePartition(): TungstenAggregationIterator = { + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + testFallbackStartsAt, + numInputRows, + numOutputRows) + } + + /** Compute a partition using the iterator already set up previously. */ + def executePartition( + context: TaskContext, + partitionIndex: Int, + aggregationIterator: TungstenAggregationIterator, + parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { + val hasInput = parentIterator.hasNext + if (!hasInput) { + // We're not using the underlying map, so we just can free it here + aggregationIterator.free() + if (groupingExpressions.isEmpty) { + numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - aggregationIterator + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator.empty } + } else { + aggregationIterator.start(parentIterator) + aggregationIterator } } + + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). 
+ val resultRdd = { + new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) + } + resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { @@ -95,11 +127,12 @@ case class TungstenAggregate( testFallbackStartsAt match { case None => val keyString = groupingExpressions.mkString("[", ",", "]") - val valueString = allAggregateExpressions.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, value=$valueString" + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" case Some(fallbackStartsAt) => s"TungstenAggregateWithControlledFallback $groupingExpressions " + - s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 32160906c3bc..26fdbc83ef50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType /** @@ -71,8 +72,6 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. - * @param inputIter - * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -82,10 +81,14 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int]) + testFallbackStartsAt: Option[Int], + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { + // The parent partition iterator, to be initialized later in `start` + private[this] var inputIter: Iterator[InternalRow] = null + /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -345,22 +348,39 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) + // Exposed for testing + private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap + // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). 
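Condensed, hedged sketch of the iterator's new lifecycle as driven by executePartition above, ignoring the empty-input, no-grouping special case: the iterator is built without input, handed the parent partition via start() (which runs the processing described below), or freed immediately when the partition is empty. `agg` and `parentRows` stand in for the values produced by the surrounding operator.

val result =
  if (parentRows.hasNext) {
    agg.start(parentRows)   // runs processInputs() or the controlled-fallback variant
    agg                     // only now do hasNext/next yield aggregated rows
  } else {
    agg.free()              // nothing to aggregate: release the hash map right away
    Iterator.empty
  }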
private def processInputs(): Unit = { - while (!sortBased && inputIter.hasNext) { - val newInput = inputIter.next() - val groupingKey = groupProjection.apply(newInput) + assert(inputIter != null, "attempted to process input when iterator was null") + if (groupingExpressions.isEmpty) { + // If there is no grouping expressions, we can just reuse the same buffer over and over again. + // Note that it would be better to eliminate the hash map entirely in the future. + val groupingKey = groupProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { + while (inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 processRow(buffer, newInput) } + } else { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } } } @@ -368,9 +388,11 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -407,6 +429,7 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() @@ -426,6 +449,11 @@ class TungstenAggregationIterator( case _ => false } + // Note: Since we spill the sorter's contents immediately after creating it, we must insert + // something into the sorter here to ensure that we acquire at least a page of memory. + // This is done through `externalSorter.insertKV`, which will trigger the page allocation. + // Otherwise, children operators may steal the window of opportunity and starve our sorter. + if (needsProcess) { // First, we create a buffer. val buffer = createNewAggregationBuffer() @@ -439,6 +467,7 @@ class TungstenAggregationIterator( // Process the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) buffer.copyFrom(initialAggregationBuffer) processRow(buffer, newInput) @@ -462,6 +491,7 @@ class TungstenAggregationIterator( // Insert the rest of input rows. 
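At the DataFrame level, the two code paths in processInputs correspond to ungrouped versus grouped aggregation; a small illustration with a hypothetical DataFrame `df`:

import org.apache.spark.sql.functions._

// No grouping expressions: a single aggregation buffer is reused for every input row.
val total = df.agg(sum("value"))

// Grouped: buffers live in the hash map and may spill into sort-based aggregation.
val perKey = df.groupBy("key").agg(sum("value"))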
while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) bufferExtractor(newInput) externalSorter.insertKV(groupingKey, buffer) @@ -581,27 +611,33 @@ class TungstenAggregationIterator( // have not switched to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// - // Starts to process input rows. - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + /** + * Start processing input rows. + * Only after this method is called will this iterator be non-empty. + */ + def start(parentIter: Iterator[InternalRow]): Unit = { + inputIter = parentIter + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } } } @@ -657,7 +693,7 @@ class TungstenAggregationIterator( TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) } - + numOutputRows += 1 res } else { // no more result @@ -666,21 +702,20 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: A utility function used to generate a output row when there is no - // input and there is no grouping expression. + // Part 8: Utility functions /////////////////////////////////////////////////////////////////////////// + /** + * Generate a output row when there is no input and there is no grouping expression. + */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - if (groupingExpressions.isEmpty) { - sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. 
- val resultCopy = - generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() - resultCopy - } else { - throw new IllegalStateException( - "This method should not be called when groupingExpressions is not empty.") - } + assert(groupingExpressions.isEmpty) + assert(inputIter == null) + generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) + } + + /** Free memory used in the underlying map. */ + def free(): Unit = { + hashMap.free() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 7619f3ec9f0a..d43d3dd9ffaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -304,7 +304,7 @@ private[sql] case class ScalaUDAF( override def nullable: Boolean = true - override def dataType: DataType = udaf.returnDataType + override def dataType: DataType = udaf.dataType override def deterministic: Boolean = udaf.deterministic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 24950f26061f..3f68b05a24f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator @@ -41,11 +41,20 @@ import org.apache.spark.{HashPartitioner, SparkEnv} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map(reusableProjection) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map { row => + numRows += 1 + reusableProjection(row) + } + } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -57,19 +66,30 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doExecute(): RDD[InternalRow] = 
child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + /** Rewrite the project list to use unsafe expressions as needed. */ + protected val unsafeProjectList = projectList.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + val project = UnsafeProjection.create(unsafeProjectList, child.output) + iter.map { row => + numRows += 1 + project(row) + } } - val project = UnsafeProjection.create(projectList, child.output) - iter.map(project) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -219,7 +239,10 @@ case class TakeOrderedAndProject( projectList: Option[Seq[NamedExpression]], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } override def outputPartitioning: Partitioning = SinglePartition @@ -245,6 +268,13 @@ case class TakeOrderedAndProject( protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala new file mode 100644 index 000000000000..f7a88b98c0b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -0,0 +1,185 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import scala.language.implicitConversions +import scala.util.matching.Regex + +import org.apache.spark.Logging +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ + + +/** + * A parser for foreign DDL commands. 
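For context, a hedged sample of the statements this parser accepts, issued through sqlContext.sql (table name, provider, and path are made up, following the forms shown in the grammar comments below):

// Register a temporary table backed by an external data source.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonTable
    |USING org.apache.spark.sql.json
    |OPTIONS (path "examples/src/main/resources/people.json")
  """.stripMargin)

// Other statements handled by this parser.
sqlContext.sql("REFRESH TABLE jsonTable")
sqlContext.sql("DESCRIBE EXTENDED jsonTable")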
+ */ +class DDLParser(parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { + + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { + try { + parse(input) + } catch { + case ddlException: DDLException => throw ddlException + case _ if !exceptionOnError => parseQuery(input) + case x: Throwable => throw x + } + } + + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") + + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable + + protected def start: Parser[LogicalPlan] = ddl + + /** + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... + */ + protected lazy val createTable: Parser[LogicalPlan] = { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. + val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableIdent, + provider, + temp.isDefined, + Array.empty[String], + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableIdent, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + } + } + + // This is the same as tableIdentifier in SqlParser. + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? 
~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { + case e ~ tableIdent => + DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> tableIdentifier ^^ { + case tableIndet => + RefreshTable(tableIndet) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex $regex", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k, v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } + + StructField(columnName, typ, nullable = true, meta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 78a4acdf4b1b..2a4c40db8bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -349,6 +349,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 6ccde7693bd3..3b7dc2e8d021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -17,27 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import 
org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources.InsertableRelation /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 66dfcc308cec..0a2007e15843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -270,6 +271,18 @@ private[sql] object PartitioningUtils { private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) + def validatePartitionColumnDataTypes( + schema: StructType, + partitionColumns: Array[String]): Unit = { + + ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field => + field.dataType match { + case _: AtomicType => // OK + case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") + } + } + } + /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala new file mode 100644 index 000000000000..8fbaf3a3059d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -0,0 +1,207 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConversions._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Success, Failure, Try} + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + + +case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + + +object ResolvedDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. */ + def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider.", error) + } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + /** Create a [[ResolvedDataSource]] for reading data in. 
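A hedged sketch of what lookupDataSource expects from a third-party source that wants short-name resolution (package, class, and short name below are hypothetical): implement DataSourceRegister alongside a provider trait and list the class in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister.

package com.example.mysource

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

class DefaultSource extends RelationProvider with DataSourceRegister {

  // Users can then write .format("mysource") instead of the fully qualified class name.
  override def shortName(): String = "mysource"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Construct and return the concrete relation; elided in this sketch.
    ???
  }
}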
*/ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: HadoopFsRelationProvider => + val maybePartitionsSchema = if (partitionColumns.isEmpty) { + None + } else { + Some(partitionColumnsSchema(schema, partitionColumns)) + } + + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable + + dataSource.createRelation( + sqlContext, + paths, + Some(dataSchema), + maybePartitionsSchema, + caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") + } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: HadoopFsRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") + } + } + new ResolvedDataSource(clazz, relation) + } + + def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String]): StructType = { + StructType(partitionColumns.map { col => + schema.find(_.name == col).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. 
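These are the user-facing calls that funnel into the read-side apply above (paths are made up): the provider can be a registered short name, a package, or a fully qualified class, and glob patterns in the path are expanded before the relation is created.

// Short name registered through DataSourceRegister.
val people = sqlContext.read.format("json").load("examples/src/main/resources/people.json")

// Package name resolved via the "<package>.DefaultSource" convention; globs are allowed.
val events = sqlContext.read
  .format("org.apache.spark.sql.parquet")
  .load("/data/events/year=2015/*")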
*/ + def apply( + sqlContext: SQLContext, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: HadoopFsRelationProvider => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions("path")) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns) + + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) + val r = dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns)), + caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. + sqlContext.executePlan( + InsertIntoHadoopFsRelation( + r, + data.logicalPlan, + mode)).toRdd + r + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + ResolvedDataSource(clazz, relation) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 2f11f4042240..78f48a5cd72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -58,6 +58,9 @@ private[sql] abstract class BaseWriterContainer( // This is only used on driver side. @transient private val jobContext: JobContext = job + private val speculationEnabled: Boolean = + relation.sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + // The following fields are initialized and used on both driver and executor side. @transient protected var outputCommitter: OutputCommitter = _ @transient private var jobId: JobID = _ @@ -126,10 +129,21 @@ private[sql] abstract class BaseWriterContainer( // associated with the file output format since it is not safe to use a custom // committer for appending. For example, in S3, direct parquet output committer may // leave partial data in the destination dir when the the appending job fails. 
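The committer selection described above is driven by two settings; a hedged configuration sketch, assuming SQLConf.OUTPUT_COMMITTER_CLASS maps to spark.sql.sources.outputCommitterClass and using a hypothetical committer class name:

import org.apache.spark.SparkConf

// A custom committer is honored only for non-append writes with speculation disabled;
// otherwise the code above falls back to the output format's default committer.
val conf = new SparkConf()
  .set("spark.speculation", "false")
  .set("spark.sql.sources.outputCommitterClass",
    "com.example.MyDirectOutputCommitter")   // hypothetical committer class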
+ // + // See SPARK-8578 for more details logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + "for appending.") defaultOutputCommitter + } else if (speculationEnabled) { + // When speculation is enabled, it's not safe to use customized output committer classes, + // especially direct output committers (e.g. `DirectParquetOutputCommitter`). + // + // See SPARK-9899 for more details. + logInfo( + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + + "because spark.speculation is configured to be true.") + defaultOutputCommitter } else { val committerClass = context.getConfiguration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) @@ -217,6 +231,8 @@ private[sql] class DefaultWriterContainer( val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) writer.initConverter(dataSchema) + var writerClosed = false + // If anything below fails, we should abort the task. try { while (iterator.hasNext) { @@ -235,7 +251,10 @@ private[sql] class DefaultWriterContainer( def commitTask(): Unit = { try { assert(writer != null, "OutputWriter instance should have been initialized") - writer.close() + if (!writerClosed) { + writer.close() + writerClosed = true + } super.commitTask() } catch { case cause: Throwable => @@ -247,7 +266,10 @@ private[sql] class DefaultWriterContainer( def abortTask(): Unit = { try { - writer.close() + if (!writerClosed) { + writer.close() + writerClosed = true + } } finally { super.abortTask() } @@ -275,6 +297,8 @@ private[sql] class DynamicPartitionWriterContainer( val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) + var outputWritersCleared = false + // Returns the partition key given an input row val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) // Returns the data columns to be written given an input row @@ -287,7 +311,7 @@ private[sql] class DynamicPartitionWriterContainer( PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } // Returns the partition path given a partition key. 
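For context on the committer selection above, a hedged configuration sketch; the config key is my reading of `SQLConf.OUTPUT_COMMITTER_CLASS`, and the committer class name is only a placeholder:

```scala
// Opting into a custom output committer for Spark SQL writes. Per the logic above,
// the setting is ignored (and the default committer used) when the write appends to
// existing data or when spark.speculation is enabled.
sqlContext.setConf(
  "spark.sql.sources.outputCommitterClass",  // believed to be SQLConf.OUTPUT_COMMITTER_CLASS.key
  "com.example.MyDirectOutputCommitter")     // placeholder committer class

// spark.speculation is a core Spark setting, read here from the SparkConf.
val speculation =
  sqlContext.sparkContext.getConf.getBoolean("spark.speculation", defaultValue = false)
```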
@@ -379,8 +403,11 @@ private[sql] class DynamicPartitionWriterContainer( } def clearOutputWriters(): Unit = { - outputWriters.asScala.values.foreach(_.close()) - outputWriters.clear() + if (!outputWritersCleared) { + outputWriters.asScala.values.foreach(_.close()) + outputWriters.clear() + outputWritersCleared = true + } } def commitTask(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 8c2f297e4245..31d6b75e1347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,340 +17,12 @@ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader - -import scala.collection.Iterator -import scala.collection.JavaConversions._ -import scala.language.{existentials, implicitConversions} -import scala.util.{Failure, Success, Try} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) 
[IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? 
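To make the grammar being removed from ddl.scala concrete, here are hedged examples of the statements it accepts, issued through `sqlContext.sql`; the paths are made up, and the Avro case assumes the external spark-avro data source is on the classpath:

```scala
sqlContext.sql(
  """CREATE TEMPORARY TABLE avroTable
    |USING org.apache.spark.sql.avro
    |OPTIONS (path "/tmp/episodes.avro")""".stripMargin)

// CTAS form: per the parser, AS SELECT cannot be combined with column definitions,
// and CREATE TEMPORARY TABLE cannot be combined with IF NOT EXISTS.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonCopy
    |USING json
    |OPTIONS (path "/tmp/json_copy")
    |AS SELECT * FROM avroTable""".stripMargin)
```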
^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource extends Logging { - - /** Given a provider name, look up the data source class definition. */ - def lookupDataSource(provider: String): Class[_] = { - val provider2 = s"$provider.DefaultSource" - val loader = Utils.getContextOrSparkClassLoader - val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - - serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name") - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - 
val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. 
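For illustration, the command backed by the `DescribeCommand` plan documented here is issued as plain SQL and produces the three columns declared in its output; a hedged sketch over a made-up table:

```scala
// DESCRIBE [EXTENDED] over a registered table named "people".
sqlContext.sql("DESCRIBE EXTENDED people")
  .show()  // yields the columns declared below: col_name, data_type, comment
```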
*/ -private[sql] case class DescribeCommand( +case class DescribeCommand( table: LogicalPlan, isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( // Column names are based on Hive. AttributeReference("col_name", StringType, nullable = false, @@ -370,7 +43,8 @@ private[sql] case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) } /** @@ -378,8 +52,8 @@ private[sql] case class DescribeCommand( * @param allowExisting If it is true, we will do nothing when the table already exists. * If it is false, an exception will be thrown */ -private[sql] case class CreateTableUsing( - tableName: String, +case class CreateTableUsing( + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, @@ -397,8 +71,9 @@ private[sql] case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -private[sql] case class CreateTableUsingAsSelect( - tableName: String, +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). +case class CreateTableUsingAsSelect( + tableIdent: TableIdentifier, provider: String, temporary: Boolean, partitionColumns: Array[String], @@ -406,12 +81,10 @@ private[sql] case class CreateTableUsingAsSelect( options: Map[String, String], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] - // TODO: Override resolved after we support databaseName. 
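A note on the `TableIdentifier` now threaded through these plans: it is an internal Catalyst class, shown here only as a hedged illustration of how a possibly database-qualified name is represented and flattened for catalog registration:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Internal API, for illustration only: an optional database plus a table name;
// toSeq is roughly what the catalog registration further below consumes.
val qualified   = TableIdentifier("people", Some("db"))  // toSeq == Seq("db", "people")
val unqualified = TableIdentifier("people", None)        // toSeq == Seq("people")
```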
- // override lazy val resolved = databaseName != None && childrenResolved } -private[sql] case class CreateTempTableUsing( - tableName: String, +case class CreateTempTableUsing( + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -419,14 +92,16 @@ private[sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + Seq.empty[Row] } } -private[sql] case class CreateTempTableUsingAsSelect( - tableName: String, +case class CreateTempTableUsingAsSelect( + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -436,14 +111,15 @@ private[sql] case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] } } -private[sql] case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -472,7 +148,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) /** * Builds a map in which keys are case insensitive */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -490,4 +166,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St /** * The exception thrown from the DDL parser. */ -protected[sql] class DDLException(message: String) extends Exception(message) +class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala new file mode 100644 index 000000000000..6773afc794f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala new file mode 100644 index 000000000000..7ccd61ed469e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Driver, DriverManager} + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped into thin wrapper. 
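Returning to the jdbc `DefaultSource` defined above: its `shortName()` lets the provider be addressed simply as "jdbc", and the option names it parses map directly onto the public reader API. A hedged sketch with made-up connection details:

```scala
// All option names below are the ones read by createRelation() above.
val jdbcDF = sqlContext.read
  .format("jdbc")                                  // matched against shortName() "jdbc"
  .option("url", "jdbc:postgresql://db:5432/app")  // required
  .option("dbtable", "public.events")              // required
  .option("driver", "org.postgresql.Driver")       // optional; triggers DriverRegistry.register
  .option("partitionColumn", "id")                 // the next three must be given together
  .option("lowerBound", "1")
  .option("upperBound", "1000000")
  .option("numPartitions", "8")
  .load()
```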
+ */ +object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala new file mode 100644 index 000000000000..18263fe227d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.util.Properties + +/** + * A wrapper for a JDBC Driver to work around SPARK-6913. + * + * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by + * Spark ClassLoader. 
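A short usage sketch for the registry introduced above; the driver class name is illustrative, and in practice the "driver" data source option performs this registration for you:

```scala
import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry

// A driver shipped via --jars is invisible to java.sql.DriverManager (which only sees
// the bootstrap/system classloader), so register() installs a DriverWrapper that
// delegates to the driver loaded through Spark's classloader.
DriverRegistry.register("org.postgresql.Driver")
```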
+ */ +class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + def getParentLogger: java.util.logging.Logger = { + throw new SQLFeatureNotSupportedException( + s"${this.getClass.getName}.getParentLogger is not yet implemented.") + } + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3cf70db6b7b0..e537d631f455 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -117,7 +118,7 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = DriverManager.getConnection(url, properties) + val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() try { @@ -170,7 +171,8 @@ private[sql] object JDBCRDD extends Logging { * getConnector is run on the driver code, while the function it returns * is run on the executor. * - * @param driver - The class name of the JDBC driver for the given url. + * @param driver - The class name of the JDBC driver for the given url, or null if the class name + * is not necessary. * @param url - The JDBC url to connect to. * * @return A function that loads the driver and connects to the url. @@ -180,9 +182,8 @@ private[sql] object JDBCRDD extends Logging { try { if (driver != null) DriverRegistry.register(driver) } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) } DriverManager.getConnection(url, properties) } @@ -344,7 +345,6 @@ private[sql] class JDBCRDD( }).toArray } - /** * Runs the SQL query against the JDBC driver. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala similarity index 71% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 48d97ced9ca0..f9300dc2cb52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties @@ -77,45 +77,6 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { - - def format(): String = "jdbc" - - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - private[sql] case class JDBCRelation( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala new file mode 100644 index 000000000000..26788b2a4fd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, PreparedStatement} +import java.util.Properties + +import scala.util.Try + +import org.apache.spark.Logging +import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Util functions for JDBC tables. + */ +object JdbcUtils extends Logging { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } + + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString()) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. + */ + def savePartition( + getConnection: () => Connection, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int], + batchSize: Int): Iterator[Byte] = { + val conn = getConnection() + var committed = false + try { + conn.setAutoCommit(false) // Everything in the same db transaction. 
+ val stmt = insertStatement(conn, table, rddSchema) + try { + var rowCount = 0 + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.addBatch() + rowCount += 1 + if (rowCount % batchSize == 0) { + stmt.executeBatch() + rowCount = 0 + } + } + if (rowCount > 0) { + stmt.executeBatch() + } + } finally { + stmt.close() + } + conn.commit() + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + conn.rollback() + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. + */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val dialect = JdbcDialects.get(url) + df.schema.fields foreach { field => { + val name = field.name + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. 
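As a hedged usage note: `saveTable` below is, as far as I can tell, what the public `DataFrameWriter.jdbc` path ends up calling, and the per-partition batch size comes from the connection properties. The table name and credentials here are made up:

```scala
import java.util.Properties

val props = new Properties()
props.setProperty("user", "app")        // passed straight through to getConnection
props.setProperty("password", "secret")
props.setProperty("batchsize", "500")   // read by saveTable below (defaults to 1000)

// df is an existing DataFrame; "events_copy" is an illustrative target table.
df.write.jdbc("jdbc:postgresql://db:5432/app", "events_copy", props)
```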
+ */ + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( + field.dataType match { + case IntegerType => java.sql.Types.INTEGER + case LongType => java.sql.Types.BIGINT + case DoubleType => java.sql.Types.DOUBLE + case FloatType => java.sql.Types.REAL + case ShortType => java.sql.Types.INTEGER + case ByteType => java.sql.Types.INTEGER + case BooleanType => java.sql.Types.BIT + case StringType => java.sql.Types.CLOB + case BinaryType => java.sql.Types.BLOB + case TimestampType => java.sql.Types.TIMESTAMP + case DateType => java.sql.Types.DATE + case t: DecimalType => java.sql.Types.DECIMAL + case _ => throw new IllegalArgumentException( + s"Can't translate null value for field $field") + }) + } + + val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val batchSize = properties.getProperty("batchsize", "1000").toInt + df.foreachPartition { iterator => + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index ec5668c6b95a..b6f3410bad69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ private[sql] object InferSchema { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 5bb9e62310a5..114c8b211891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter @@ -39,9 +39,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "json" +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "json" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index d734e7e8904b..99ac7730bd1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow @@ -107,12 +107,12 @@ private[sql] object JacksonGenerator { v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() - case (MapType(kv, vv, _), v: Map[_, _]) => + case (MapType(kt, vt, _), v: MapData) => gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv, p._2) - } + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) gen.writeEndObject() case (StructType(ty), v: InternalRow) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index b8fd3b9cc150..cd68bd667c5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. 
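Picking up the `shortName()` change made to the JSON `DefaultSource` above: a hedged sketch of how a third-party source adopts the renamed API (previously `format()`); the short name and error body are placeholders:

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// shortName() is what lookupDataSource matches, provided the class is also listed in
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister.
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "myformat"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    sys.error("illustration only: build and return a BaseRelation here")
  }
}
```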
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala index fde96852ce68..005546f37dda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 975fec101d9c..3f8353af6e2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { + // Called after `init()` when initializing Parquet record reader. override def prepareForRead( conf: Configuration, keyValueMetaData: JMap[String, String], @@ -51,19 +52,29 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // available if the target file is written by Spark SQL. .orElse(metadata.get(CatalystReadSupport.SPARK_METADATA_KEY)) }.map(StructType.fromString).getOrElse { - logDebug("Catalyst schema not available, falling back to Parquet schema") + logInfo("Catalyst schema not available, falling back to Parquet schema") toCatalyst.convert(parquetRequestedSchema) } - logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") + logInfo { + s"""Going to read the following fields from the Parquet file: + | + |Parquet form: + |$parquetRequestedSchema + |Catalyst form: + |$catalystRequestedSchema + """.stripMargin + } + new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) } + // Called before `prepareForRead()` when initializing Parquet record reader. 
override def init(context: InitContext): ReadContext = { val conf = context.getConfiguration // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst - // schema of this file from its the metadata. + // schema of this file from its metadata. val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) // Optional schema of requested columns, in the form of a string serialized from a Catalyst @@ -141,7 +152,6 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") new ReadContext(parquetRequestedSchema, metadata) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index 84f1dccfeb78..ed9e0aa65977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala similarity index 63% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 4fe8a39f20ab..d2c2db51769b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -15,20 +15,22 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} +import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -42,6 +44,12 @@ import org.apache.spark.unsafe.types.UTF8String * values to an [[ArrayBuffer]]. 
*/ private[parquet] trait ParentContainerUpdater { + /** Called before a record field is being converted */ + def start(): Unit = () + + /** Called after a record field is being converted */ + def end(): Unit = () + def set(value: Any): Unit = () def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) @@ -55,13 +63,81 @@ private[parquet] trait ParentContainerUpdater { /** A no-op updater used for root converter (who doesn't have a parent). */ private[parquet] object NoopUpdater extends ParentContainerUpdater +private[parquet] trait HasParentContainerUpdater { + def updater: ParentContainerUpdater +} + /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. - * Since any Parquet record is also a struct, this converter can also be used as root converter. + * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + */ +private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater + +/** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ +private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) + extends PrimitiveConverter with HasParentContainerUpdater { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) +} + +/** + * A [[CatalystRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root + * converter. Take the following Parquet type as an example: + * {{{ + * message root { + * required int32 f1; + * optional group f2 { + * required double f21; + * optional binary f22 (utf8); + * } + * } + * }}} + * 5 converters will be created: + * + * - a root [[CatalystRowConverter]] for [[MessageType]] `root`, which contains: + * - a [[CatalystPrimitiveConverter]] for required [[INT_32]] field `f1`, and + * - a nested [[CatalystRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[CatalystPrimitiveConverter]] for required [[DOUBLE]] field `f21`, and + * - a [[CatalystStringConverter]] for optional [[UTF8]] string field `f22` * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. * + * @note Constructor argument [[parquetType]] refers to requested fields of the actual schema of the + * Parquet file being read, while constructor argument [[catalystType]] refers to requested + * fields of the global schema. The key difference is that, in case of schema merging, + * [[parquetType]] can be a subset of [[catalystType]]. 
For example, it's possible to have + * the following [[catalystType]]: + * {{{ + * new StructType() + * .add("f1", IntegerType, nullable = false) + * .add("f2", StringType, nullable = true) + * .add("f3", new StructType() + * .add("f31", DoubleType, nullable = false) + * .add("f32", IntegerType, nullable = true) + * .add("f33", StringType, nullable = true), nullable = false) + * }}} + * and the following [[parquetType]] (`f2` and `f32` are missing): + * {{{ + * message root { + * required int32 f1; + * required group f3 { + * required double f31; + * optional binary f33 (utf8); + * } + * } + * }}} + * * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type * @param updater An updater which propagates converted field values to the parent container @@ -70,7 +146,16 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) with Logging { + + logDebug( + s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -89,14 +174,29 @@ private[parquet] class CatalystRowConverter( /** * Represents the converted row object once an entire Parquet record is converted. - * - * @todo Uses [[UnsafeRow]] for better performance. */ val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) // Converters for each field. - private val fieldConverters: Array[Converter] = { - parquetType.getFields.zip(catalystType).zipWithIndex.map { + private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { + // In case of schema merging, `parquetType` can be a subset of `catalystType`. We need to pad + // those missing fields and create converters for them, although values of these fields are + // always null. + val paddedParquetFields = { + val parquetFields = parquetType.getFields + val parquetFieldNames = parquetFields.map(_.getName).toSet + val missingFields = catalystType.filterNot(f => parquetFieldNames.contains(f.name)) + + // We don't need to worry about feature flag arguments like `assumeBinaryIsString` when + // creating the schema converter here, since values of missing fields are always null. 
+ val toParquet = new CatalystSchemaConverter() + + (parquetFields ++ missingFields.map(toParquet.convertField)).sortBy { f => + catalystType.indexWhere(_.name == f.getName) + } + } + + paddedParquetFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) @@ -105,11 +205,19 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def end(): Unit = updater.set(currentRow) + override def end(): Unit = { + var i = 0 + while (i < currentRow.numFields) { + fieldConverters(i).updater.end() + i += 1 + } + updater.set(currentRow) + } override def start(): Unit = { var i = 0 while (i < currentRow.numFields) { + fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } @@ -122,20 +230,20 @@ private[parquet] class CatalystRowConverter( private def newConverter( parquetType: Type, catalystType: DataType, - updater: ParentContainerUpdater): Converter = { + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new CatalystPrimitiveConverter(updater) case ByteType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } @@ -148,7 +256,7 @@ private[parquet] class CatalystRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -164,13 +272,23 @@ private[parquet] class CatalystRowConverter( } case DateType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. updater.set(value.asInstanceOf[DateType#InternalType]) } } + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + case t: ArrayType => new CatalystArrayConverter(parquetType.asGroupType(), t, updater) @@ -195,27 +313,11 @@ private[parquet] class CatalystRowConverter( } } - /** - * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types - * are handled by this converter. Parquet primitive types are only a subset of those of Spark - * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. 
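The `ArrayType` dispatch added a few lines above (and the matching `CatalystSchemaConverter` change later in this diff) encodes one rule: a repeated field that carries no LIST/MAP annotation is read as a required array of required elements. A small sketch of that rule using Spark SQL's public type API (spark-sql on the classpath assumed; `ToyParquetField` is an illustrative stand-in, not a real Parquet `Type`):

{{{
import org.apache.spark.sql.types._

object RepeatedFieldSketch {
  // Only the pieces of a Parquet field the rule cares about: `repeated` stands for the REPEATED
  // repetition level, `annotatedAsListOrMap` for a LIST/MAP original type annotation.
  case class ToyParquetField(
      name: String, elementType: DataType, repeated: Boolean, annotatedAsListOrMap: Boolean)

  // An unannotated repeated field is read as a required array of required elements.
  def toStructField(f: ToyParquetField): StructField =
    if (f.repeated && !f.annotatedAsListOrMap) {
      StructField(f.name, ArrayType(f.elementType, containsNull = false), nullable = false)
    } else {
      StructField(f.name, f.elementType, nullable = true)
    }

  def main(args: Array[String]): Unit = {
    // e.g. `repeated int32 ids;` as written by older protobuf/thrift based writers
    println(toStructField(
      ToyParquetField("ids", IntegerType, repeated = true, annotatedAsListOrMap = false)))
    // StructField(ids,ArrayType(IntegerType,false),false)
  }
}
}}}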
- */ - private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { - - override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) - override def addInt(value: Int): Unit = updater.setInt(value) - override def addLong(value: Long): Unit = updater.setLong(value) - override def addFloat(value: Float): Unit = updater.setFloat(value) - override def addDouble(value: Double): Unit = updater.setDouble(value) - override def addBinary(value: Binary): Unit = updater.set(value.getBytes) - } - /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. */ private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -242,7 +344,7 @@ private[parquet] class CatalystRowConverter( private final class CatalystDecimalConverter( decimalType: DecimalType, updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { @@ -306,7 +408,7 @@ private[parquet] class CatalystRowConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ @@ -372,9 +474,15 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = converter - override def end(): Unit = currentArray += currentElement + override def end(): Unit = { + converter.updater.end() + currentArray += currentElement + } - override def start(): Unit = currentElement = null + override def start(): Unit = { + converter.updater.start() + currentElement = null + } } } @@ -383,7 +491,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -446,4 +554,61 @@ private[parquet] class CatalystRowConverter( } } } + + private trait RepeatedConverter { + private var currentArray: ArrayBuffer[Any] = _ + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. 
+ */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. + */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index b12149dcf1c9..535f0684e97f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ @@ -25,7 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ -import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} @@ -72,18 +72,9 @@ private[parquet] class CatalystSchemaConverter( followParquetFormatSpec = conf.followParquetFormatSpec) def this(conf: Configuration) = this( - assumeBinaryIsString = - conf.getBoolean( - SQLConf.PARQUET_BINARY_AS_STRING.key, - SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), - assumeInt96IsTimestamp = - conf.getBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), - followParquetFormatSpec = - conf.getBoolean( - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, + assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, + followParquetFormatSpec = conf.get(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -100,8 +91,11 @@ private[parquet] class CatalystSchemaConverter( StructField(field.getName, convertField(field), nullable = false) case REPEATED => - throw new AnalysisException( - s"REPEATED not supported outside LIST or MAP. Type: $field") + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + val arrayType = ArrayType(convertField(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1551afd7b7bf..2c6b914328b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala index 6ed3580af072..ccd7ebf319af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{MapData, ArrayData} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala similarity index 68% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d57b789f5c1c..c74c8388632f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.compat.FilterCompat -import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} -import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.OriginalType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ @@ -38,12 +37,6 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap { filter => - createFilter(filter) - }.reduceOption(FilterApi.and).map(FilterCompat.get) - } - case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { @@ -197,11 +190,23 @@ private[sql] object ParquetFilters { def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + relaxParquetValidTypeMap + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, // which can be casted to `false` implicitly. 
Please refer to the `eval` method of these
    // operators and the `SimplifyFilters` rule for details.
+
+    // Hyukjin:
+    // [[sources.EqualNullSafe]] is mapped to [[org.apache.parquet.filter2.predicate.Operators.Eq]],
+    // so it is pushed down exactly like [[sources.EqualTo]]. This works because the generated
+    // Parquet filter already performs a null-safe equality comparison. Strictly speaking,
+    // [[EqualTo]] might deserve separate handling, but physical planning never passes `NULL` to
+    // [[EqualTo]]; it rewrites such predicates to [[IsNull]] etc., so sharing `makeEq`/`makeNotEq`
+    // here is safe. If that assumption ever breaks, the [[EqualTo]] mapping should be revisited.
+
     predicate match {
       case sources.IsNull(name) =>
         makeEq.lift(dataTypeOf(name)).map(_(name, null))
@@ -213,6 +218,11 @@ private[sql] object ParquetFilters {
       case sources.Not(sources.EqualTo(name, value)) =>
         makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
 
+      case sources.EqualNullSafe(name, value) =>
+        makeEq.lift(dataTypeOf(name)).map(_(name, value))
+      case sources.Not(sources.EqualNullSafe(name, value)) =>
+        makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
+
       case sources.LessThan(name, value) =>
         makeLt.lift(dataTypeOf(name)).map(_(name, value))
       case sources.LessThanOrEqual(name, value) =>
@@ -239,94 +249,35 @@ private[sql] object ParquetFilters {
     }
   }
 
-  /**
-   * Converts Catalyst predicate expressions to Parquet filter predicates.
-   *
-   * @todo This can be removed once we get rid of the old Parquet support.
-   */
-  def createFilter(predicate: Expression): Option[FilterPredicate] = {
-    // NOTE:
-    //
-    // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
-    // which can be casted to `false` implicitly. Please refer to the `eval` method of these
-    // operators and the `SimplifyFilters` rule for details.
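The removal that continues below drops the old `Expression`-based `createFilter`; only the `sources.Filter` path above survives. Its shape, `makeXxx.lift(columnType).map(builder)`, can be pictured with a self-contained toy (illustrative ADT and `Predicate` type, not the real parquet-mr `FilterApi`): lifting a partial function returns `None` for column types that cannot be pushed down, and `EqualNullSafe` deliberately reuses the `EqualTo` builder.

{{{
object FilterMappingSketch {
  // Toy stand-ins for the data source filters involved (names mirror org.apache.spark.sql.sources).
  sealed trait ToyFilter
  case class EqualTo(attr: String, value: Any) extends ToyFilter
  case class EqualNullSafe(attr: String, value: Any) extends ToyFilter
  case class IsNull(attr: String) extends ToyFilter

  // A "compiled" predicate is just a description here; in the real code it is a FilterPredicate.
  case class Predicate(description: String)

  // Pretend only integer columns support push-down, mirroring the `makeEq.lift(...)` pattern:
  // lifting yields None for unsupported column types, i.e. no push-down.
  val makeEq: PartialFunction[String, (String, Any) => Predicate] = {
    case "int" => (name, v) => Predicate(s"$name = $v")
  }

  def createFilter(typeOf: Map[String, String], filter: ToyFilter): Option[Predicate] = filter match {
    case IsNull(name)               => makeEq.lift(typeOf(name)).map(_(name, null))
    case EqualTo(name, value)       => makeEq.lift(typeOf(name)).map(_(name, value))
    // EqualNullSafe reuses the same builder: the underlying comparison is null-safe anyway.
    case EqualNullSafe(name, value) => makeEq.lift(typeOf(name)).map(_(name, value))
  }

  def main(args: Array[String]): Unit = {
    val schema = Map("a" -> "int", "s" -> "struct")
    println(createFilter(schema, EqualNullSafe("a", 1)))  // Some(Predicate(a = 1))
    println(createFilter(schema, EqualTo("s", "x")))      // None: not push-down eligible
  }
}
}}}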
- predicate match { - case IsNull(NamedExpression(name, dataType)) => - makeEq.lift(dataType).map(_(name, null)) - case IsNotNull(NamedExpression(name, dataType)) => - makeNotEq.lift(dataType).map(_(name, null)) - - case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeEq.lift(dataType).map(_(name, value)) - - case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - - case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGt.lift(dataType).map(_(name, value)) - - case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - - case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLt.lift(dataType).map(_(name, value)) - - case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - - case And(lhs, rhs) => - (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) - - case Or(lhs, rhs) => - for { - lhsFilter <- createFilter(lhs) - rhsFilter 
<- createFilter(rhs) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case Not(pred) => - createFilter(pred).map(FilterApi.not) - - case InSet(NamedExpression(name, dataType), valueSet) => - makeInSet.lift(dataType).map(_(name, valueSet)) - - case _ => None - } + // !! HACK ALERT !! + // + // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to + // parquet-mr 1.8.1 or higher versions. + // + // In Parquet, not all types of columns can be used for filter push-down optimization. The set + // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and + // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be + // pushed down. + // + // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps + // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. Thus, + // a predicate involving a `ENUM` field can be pushed-down as a string column, which is perfectly + // legal except that it fails the `ValidTypeMap` check. + // + // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to workaround this issue. + private lazy val relaxParquetValidTypeMap: Unit = { + val constructor = Class + .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor") + .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType]) + + constructor.setAccessible(true) + val enumTypeDescriptor = constructor + .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM) + .asInstanceOf[AnyRef] + + val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get + addMethod.setAccessible(true) + addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b6db71b5b8a6..bbf682aec0f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -15,10 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.logging.{Level, Logger => JLogger} +import java.util.logging.{Logger => JLogger} import java.util.{List => JList} import scala.collection.JavaConversions._ @@ -26,32 +26,33 @@ import scala.collection.mutable import scala.util.{Failure, Try} import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetRecordReader, _} import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ParquetLog} +import org.apache.parquet.{Log => ApacheParquetLog} +import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} -import org.apache.spark.rdd.RDD._ +import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "parquet" + override def shortName(): String = "parquet" override def createRelation( sqlContext: SQLContext, @@ -209,6 +210,13 @@ private[sql] class ParquetRelation( override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) + // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible + val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[DirectParquetOutputCommitter].getCanonicalName) + } + val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, @@ -274,12 +282,18 @@ private[sql] class ParquetRelation( val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + // Parquet row group size. We will use this value as the value for + // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value + // of these flags are smaller than the parquet row group size. + val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) + // Create the function to set variable Parquet confs at both driver and executor side. 
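The `parquetBlockSize` computed above feeds the job-setup functions created next, which apply `overrideMinSplitSize` (defined further down in this diff) on both the driver and the executors. Its effect in isolation, using nothing but a plain Hadoop `Configuration`: if the row group is larger than the configured minimum split size, the minimum is raised so no split can start inside a row group and end up as an empty task. A minimal sketch restating that helper:

{{{
import org.apache.hadoop.conf.Configuration

object MinSplitSizeSketch {
  // Raise the minimum split size to the Parquet row group size when needed, so every split
  // covers at least one full row group; otherwise some splits map to no row group at all.
  def overrideMinSplitSize(parquetBlockSize: Long, conf: Configuration): Unit = {
    val minSplitSize = math.max(
      conf.getLong("mapred.min.split.size", 0L),
      conf.getLong("mapreduce.input.fileinputformat.split.minsize", 0L))
    if (parquetBlockSize > minSplitSize) {
      conf.setLong("mapred.min.split.size", parquetBlockSize)
      conf.setLong("mapreduce.input.fileinputformat.split.minsize", parquetBlockSize)
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new Configuration(false)
    conf.setLong("mapreduce.input.fileinputformat.split.minsize", 32L * 1024 * 1024)
    overrideMinSplitSize(parquetBlockSize = 128L * 1024 * 1024, conf)
    println(conf.get("mapreduce.input.fileinputformat.split.minsize")) // 134217728
  }
}
}}}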
val initLocalJobFuncOpt = ParquetRelation.initializeLocalJobFunc( requiredColumns, filters, dataSchema, + parquetBlockSize, useMetadataCache, parquetFilterPushDown, assumeBinaryIsString, @@ -287,7 +301,8 @@ private[sql] class ParquetRelation( followParquetFormatSpec) _ // Create the function to set input paths at the driver side. - val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + val setInputPaths = + ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( @@ -475,11 +490,35 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + /** + * If parquet's block size (row group size) setting is larger than the min split size, + * we use parquet's block size setting as the min split size. Otherwise, we will create + * tasks processing nothing (because a split does not cover the starting point of a + * parquet block). See https://issues.apache.org/jira/browse/SPARK-10143 for more information. + */ + private def overrideMinSplitSize(parquetBlockSize: Long, conf: Configuration): Unit = { + val minSplitSize = + math.max( + conf.getLong("mapred.min.split.size", 0L), + conf.getLong("mapreduce.input.fileinputformat.split.minsize", 0L)) + if (parquetBlockSize > minSplitSize) { + val message = + s"Parquet's block size (row group size) is larger than " + + s"mapred.min.split.size/mapreduce.input.fileinputformat.split.minsize. Setting " + + s"mapred.min.split.size and mapreduce.input.fileinputformat.split.minsize to " + + s"$parquetBlockSize." + logDebug(message) + conf.set("mapred.min.split.size", parquetBlockSize.toString) + conf.set("mapreduce.input.fileinputformat.split.minsize", parquetBlockSize.toString) + } + } + /** This closure sets various Parquet configurations at both driver side and executor side. */ private[parquet] def initializeLocalJobFunc( requiredColumns: Array[String], filters: Array[Filter], dataSchema: StructType, + parquetBlockSize: Long, useMetadataCache: Boolean, parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, @@ -515,16 +554,21 @@ private[sql] object ParquetRelation extends Logging { conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + + overrideMinSplitSize(parquetBlockSize, conf) } /** This closure sets input paths at the driver side. */ private[parquet] def initializeDriverSideJobFunc( - inputFiles: Array[FileStatus])(job: Job): Unit = { + inputFiles: Array[FileStatus], + parquetBlockSize: Long)(job: Job): Unit = { // We side the input paths at the driver side. logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") if (inputFiles.nonEmpty) { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } + + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( @@ -671,7 +715,7 @@ private[sql] object ParquetRelation extends Logging { val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - // HACK ALERT: + // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. 
However, `FileStatus` is not `Serializable` @@ -752,38 +796,39 @@ private[sql] object ParquetRelation extends Logging { }.toOption } - def enableLogForwarding() { - // Note: the org.apache.parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "org.apache.parquet". This - // checks first to see if there's any handlers already set - // and if not it creates them. If this method executes prior - // to that class being loaded then: - // 1) there's no handlers installed so there's none to - // remove. But when it IS finally loaded the desired affect - // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHandlers(false) - // undoing the attempt to override the logging here. - // - // Therefore we need to force the class to be loaded. - // This should really be resolved by Parquet. - Utils.classForName(classOf[ParquetLog].getName) - - // Note: Logger.getLogger("parquet") has a default logger - // that appends to Console which needs to be cleared. - val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) - parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - parquetLogger.setUseParentHandlers(true) - - // Disables a WARN log message in ParquetOutputCommitter. We first ensure that - // ParquetOutputCommitter is loaded and the static LOG field gets initialized. - // See https://issues.apache.org/jira/browse/SPARK-5968 for details - Utils.classForName(classOf[ParquetOutputCommitter].getName) - JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) - - // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. - // See https://issues.apache.org/jira/browse/PARQUET-220 for details - Utils.classForName(classOf[ParquetRecordReader[_]].getName) - JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) + // JUL loggers must be held by a strong reference, otherwise they may get destroyed by GC. + // However, the root JUL logger used by Parquet isn't properly referenced. Here we keep + // references to loggers in both parquet-mr <= 1.6 and >= 1.7 + val apacheParquetLogger: JLogger = JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName) + val parquetLogger: JLogger = JLogger.getLogger("parquet") + + // Parquet initializes its own JUL logger in a static block which always prints to stdout. Here + // we redirect the JUL logger via SLF4J JUL bridge handler. + val redirectParquetLogsViaSLF4J: Unit = { + def redirect(logger: JLogger): Unit = { + logger.getHandlers.foreach(logger.removeHandler) + logger.setUseParentHandlers(false) + logger.addHandler(new SLF4JBridgeHandler) + } + + // For parquet-mr 1.7.0 and above versions, which are under `org.apache.parquet` namespace. + // scalastyle:off classforname + Class.forName(classOf[ApacheParquetLog].getName) + // scalastyle:on classforname + redirect(JLogger.getLogger(classOf[ApacheParquetLog].getPackage.getName)) + + // For parquet-mr 1.6.0 and lower versions bundled with Hive, which are under `parquet` + // namespace. + try { + // scalastyle:off classforname + Class.forName("parquet.Log") + // scalastyle:on classforname + redirect(JLogger.getLogger("parquet")) + } catch { case _: Throwable => + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly jar + // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block + // should be removed after this issue is fixed. 
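The block above replaces the old `enableLogForwarding` approach: Parquet's java.util.logging output is routed through the SLF4J bridge, and strong references to the JUL loggers are kept so the handler setup cannot be dropped by GC. The same redirection pattern in a minimal standalone form (assumes `org.slf4j:jul-to-slf4j` on the classpath; the logger name is just an example):

{{{
import java.util.logging.{Logger => JLogger}
import org.slf4j.bridge.SLF4JBridgeHandler

object JulRedirectSketch {
  // Keep a strong reference: JUL holds loggers weakly, so an unreferenced logger
  // (and the handlers installed on it) may be garbage collected at any time.
  val parquetJulLogger: JLogger = JLogger.getLogger("org.apache.parquet")

  def redirectToSlf4j(logger: JLogger): Unit = {
    logger.getHandlers.foreach(logger.removeHandler)  // drop the default console handler
    logger.setUseParentHandlers(false)                // avoid double-logging via the root logger
    logger.addHandler(new SLF4JBridgeHandler)         // hand records to SLF4J instead
  }

  def main(args: Array[String]): Unit = {
    redirectToSlf4j(parquetJulLogger)
    parquetJulLogger.info("this record now flows through SLF4J")
  }
}
}}}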
+ } } // The parquet compression short names diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala index 9cd0250f9c51..ed89aa27aa1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} @@ -52,7 +52,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } log.debug(s"write support initialized for requested schema $attributes") - ParquetRelation.enableLogForwarding() new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 3854f5bd39fb..42376ef7a9c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException @@ -104,7 +104,6 @@ private[parquet] object ParquetTypesConverter extends Logging { extraMetadata, "Spark") - ParquetRelation.enableLogForwarding() ParquetFileWriter.writeMetadataFile( conf, path, @@ -140,8 +139,6 @@ private[parquet] object ParquetTypesConverter extends Logging { (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE } - ParquetRelation.enableLogForwarding() - // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row // groups. Since Parquet schema is replicated among all row groups, we only need to touch a // single row group to read schema related metadata. Notice that we are making assumptions that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 40ca8bf4095d..16c9138419fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -116,6 +116,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray) + // Get all input data source relations of the query. 
val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src @@ -138,12 +140,12 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => + case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent.toSeq)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation) => @@ -153,7 +155,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableName that is also being read from.") + s"Cannot overwrite table $tableIdent that is also being read from.") } else { // OK } @@ -164,6 +166,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns) + case _ => // OK } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f7a68e4f5d44..2e108cb81451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.{InternalAccumulator, TaskContext} @@ -45,7 +46,10 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout @@ -65,6 +69,11 @@ case class BroadcastHashJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = buildSide match { + case BuildLeft => longMetric("numLeftRows") + case BuildRight => longMetric("numRightRows") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
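From here on, the join operators all follow the same metrics pattern: declare `numLeftRows`/`numRightRows`/`numOutputRows`, then increment them inside the iterators and map functions that rows actually flow through, as the build future below does while copying rows. Stripped of `LongSQLMetric`, the counting is just an iterator wrapper; a toy version with plain mutable counters:

{{{
object RowCountingSketch {
  // Toy metric: in Spark this is a LongSQLMetric backed by an accumulator so values show up
  // in the SQL UI; here a mutable cell is enough to show the wiring.
  final class ToyMetric { var value = 0L; def +=(n: Long): Unit = value += n }

  // Count rows as they flow through, without materializing the iterator.
  def counted[T](iter: Iterator[T], metric: ToyMetric): Iterator[T] =
    iter.map { row => metric += 1; row }

  def main(args: Array[String]): Unit = {
    val numBuildRows = new ToyMetric
    val numOutputRows = new ToyMetric
    val build = counted(Iterator(1, 2, 3), numBuildRows).toArray          // "collect" the build side
    val output = counted(Iterator(10, 20).filter(_ > 10), numOutputRows)  // stream the probe side
    output.foreach(_ => ())                                               // drive the iterator
    println((numBuildRows.value, numOutputRows.value))                    // (3, 1)
  }
}
}}}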
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -73,8 +82,15 @@ case class BroadcastHashJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -85,6 +101,12 @@ case class BroadcastHashJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = buildSide match { + case BuildLeft => longMetric("numRightRows") + case BuildRight => longMetric("numLeftRows") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -95,7 +117,7 @@ case class BroadcastHashJoin( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashJoin(streamedIter, hashedRelation) + hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index a3626de49aea..69a8b95eaa7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.{InternalAccumulator, TaskContext} /** @@ -45,6 +46,11 @@ case class BroadcastHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -63,6 +69,14 @@ case class BroadcastHashOuterJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = joinType match { + case RightOuter => longMetric("numLeftRows") + case LeftOuter => longMetric("numRightRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
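The outer-join build below has the same overall shape as `BroadcastHashJoin`: collect the (already counted) build side, construct a `HashedRelation` from it, broadcast it, and probe it on the streamed side; the hash build itself runs outside any job, hence the `SQLMetrics.nullLongMetric` placeholder. That shape, reduced to plain Scala collections (hypothetical names, no actual broadcast):

{{{
object BroadcastBuildSketch {
  type Row = Seq[Any]

  // Group build-side rows by join key: the toy analogue of HashedRelation.apply.
  def buildHashedRelation(buildRows: Iterator[Row], key: Row => Any): Map[Any, Seq[Row]] =
    buildRows.toSeq.groupBy(key)

  // Probe side: stream rows and look their key up in the (conceptually broadcast) relation.
  def hashJoin(streamRows: Iterator[Row], key: Row => Any, hashed: Map[Any, Seq[Row]]): Iterator[Row] =
    streamRows.flatMap { s => hashed.getOrElse(key(s), Seq.empty).map(b => s ++ b) }

  def main(args: Array[String]): Unit = {
    val build = Iterator[Row](Seq(1, "a"), Seq(2, "b"))
    val stream = Iterator[Row](Seq(1, "x"), Seq(3, "y"))
    val hashed = buildHashedRelation(build, _.head)   // would be sparkContext.broadcast(...) in Spark
    println(hashJoin(stream, _.head, hashed).toList)  // List(List(1, x, 1, a))
  }
}
}}}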
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -71,8 +85,15 @@ case class BroadcastHashOuterJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -83,6 +104,15 @@ case class BroadcastHashOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = joinType match { + case RightOuter => longMetric("numRightRows") + case LeftOuter => longMetric("numLeftRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -101,16 +131,18 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) }) case RightOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 5bd06fbdca60..78a8c16c62bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,18 +38,31 @@ case class BroadcastLeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val input = right.execute().map(_.copy()).collect() + val numLeftRows = longMetric("numLeftRows") + val 
numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val input = right.execute().map { row => + numRightRows += 1 + row.copy() + }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) + val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) + val hashRelation = + HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => @@ -59,7 +73,7 @@ case class BroadcastLeftSemiJoinHash( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashSemiJoin(streamIter, hashedRelation) + hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 23aebf4b068b..28c88b1b03d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer /** @@ -38,6 +39,11 @@ case class BroadcastNestedLoopJoin( condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + /** BuildRight means the right relation <=> the broadcast relation. 
*/ private val (streamed, broadcast) = buildSide match { case BuildRight => (left, right) @@ -65,8 +71,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } @@ -74,9 +81,17 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val (numStreamedRows, numBuildRows) = buildSide match { + case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()) - .collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => @@ -93,6 +108,7 @@ case class BroadcastNestedLoopJoin( streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false + numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) @@ -161,6 +177,12 @@ case class BroadcastNestedLoopJoin( // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), + sparkContext.makeRDD(broadcastRowsWithNulls) + ).map { row => + // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. 
+ numOutputRows += 1 + row + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 261b4724159f..2115f4070228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -30,13 +31,31 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().map { row => + numLeftRows += 1 + row.copy() + } + val rightResults = right.execute().map { row => + numRightRows += 1 + row.copy() + } leftResults.cartesian(rightResults).mapPartitions { iter => val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) + iter.map { r => + numOutputRows += 1 + joinedRow(r._1, r._2) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b..7ce4a517838c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashJoin { @@ -69,7 +70,9 @@ trait HashJoin { protected def hashJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ @@ -98,6 +101,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 + numOutputRows += 1 resultProjection(ret) } @@ -113,6 +117,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() + numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd8637..66903347c88c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer @DeveloperApi @@ -114,22 +115,28 @@ trait HashOuterJoin { key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } @@ -140,22 +147,28 @@ trait HashOuterJoin { key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } @@ -164,7 +177,7 @@ trait HashOuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -177,6 +190,7 @@ trait HashOuterJoin { // append them directly case (r, idx) if boundCondition(joinedRow.withRight(r)) => + numOutputRows += 1 matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) @@ -189,6 +203,7 @@ trait HashOuterJoin { // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. + numOutputRows += 1 joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { @@ -197,12 +212,15 @@ trait HashOuterJoin { // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. 
case (r, idx) if !rightMatchedSet.contains(idx) => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[InternalRow] { l => + numOutputRows += 1 joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[InternalRow] { r => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } @@ -211,10 +229,12 @@ trait HashOuterJoin { // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], + numIterRows: LongSQLMetric, keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() + numIterRows += 1 val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed..beb141ade616 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashSemiJoin { @@ -61,13 +62,15 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { val currentRow = buildIter.next() + numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -82,25 +85,35 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashSet: java.util.Set[InternalRow], + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { + numStreamRows += 1 val key = joinKeys(current) - !key.anyNull && hashSet.contains(key) + val r = !key.anyNull && hashSet.contains(key) + if (r) numOutputRows += 1 + r }) } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => + numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } + if (r) numOutputRows += 1 + r } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala 
index 953abf409f22..6c0196c21a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,9 +25,10 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -65,7 +66,8 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) @@ -87,7 +89,8 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) @@ -112,11 +115,13 @@ private[joins] object HashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + return UnsafeHashedRelation( + input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. 
@@ -130,6 +135,7 @@ private[joins] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() + numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -209,8 +215,10 @@ private[joins] final class UnsafeHashedRelation( if (binaryMap != null) { // Used in Broadcast join - val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes) + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() @@ -218,8 +226,8 @@ private[joins] final class UnsafeHashedRelation( var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength while (offset < last) { - val numFields = PlatformDependent.UNSAFE.getInt(base, offset) - val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + val numFields = Platform.getInt(base, offset) + val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 val row = new UnsafeRow @@ -239,40 +247,67 @@ private[joins] final class UnsafeHashedRelation( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(hashTable.size()) - - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - val key = entry.getKey - val values = entry.getValue - - // write all the values as single byte array - var totalSize = 0L - var i = 0 - while (i < values.length) { - totalSize += values(i).getSizeInBytes + 4 + 4 - i += 1 + if (binaryMap != null) { + // This could happen when a cached broadcast object need to be dumped into disk to free memory + out.writeInt(binaryMap.numElements()) + + var buffer = new Array[Byte](64) + def write(addr: MemoryLocation, length: Int): Unit = { + if (buffer.length < length) { + buffer = new Array[Byte](length) + } + Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, + buffer, Platform.BYTE_ARRAY_OFFSET, length) + out.write(buffer, 0, length) } - assert(totalSize < Integer.MAX_VALUE, "values are too big") - - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(key.getSizeInBytes) - out.writeInt(totalSize.toInt) - out.write(key.getBytes) - i = 0 - while (i < values.length) { - // [num of fields] [num of bytes] [row bytes] - // write the integer in native order, so they can be read by UNSAFE.getInt() - if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { - out.writeInt(values(i).numFields()) - out.writeInt(values(i).getSizeInBytes) - } else { - out.writeInt(Integer.reverseBytes(values(i).numFields())) - out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + + val iter = binaryMap.iterator() + while (iter.hasNext) { + val loc = iter.next() + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(loc.getKeyLength) + out.writeInt(loc.getValueLength) + write(loc.getKeyAddress, loc.getKeyLength) + write(loc.getValueAddress, loc.getValueLength) + } + + } else { + assert(hashTable != null) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += 
values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 } - out.write(values(i).getBytes) - i += 1 } } } @@ -295,7 +330,7 @@ private[joins] final class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, - nKeys * 2, // reduce hash collision + (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) var i = 0 @@ -314,10 +349,11 @@ private[joins] final class UnsafeHashedRelation( in.readFully(valuesBuffer, 0, valuesSize) // put it into binary map - val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, - valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + val putSuceeded = loc.putNewKey( + keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { throw new IOException("Could not allocate memory to grow BytesToBytesMap") } @@ -330,6 +366,7 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { @@ -339,6 +376,7 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 4443455ef11f..ad6362542f2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -35,6 +36,11 @@ case class LeftSemiJoinBNL( extends BinaryNode { // TODO: Override requiredChildDistribution. 
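The writeExternal/readExternal hunks above serialize each map entry as a length-prefixed record: [key size] [values size] [key bytes] [values bytes]. A minimal sketch of that outer framing using plain java.io streams (illustrative names of my own; the real code additionally writes a per-row [num fields] [num bytes] header in native byte order so it can be read back with unsafe getInt):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Sketch of the length-prefixed entry framing used when dumping the hashed relation.
object FramingSketch {
  def writeEntry(out: DataOutputStream, key: Array[Byte], values: Array[Byte]): Unit = {
    out.writeInt(key.length)      // [key size]
    out.writeInt(values.length)   // [values size]
    out.write(key)                // [key bytes]
    out.write(values)             // [values bytes]
  }

  def readEntry(in: DataInputStream): (Array[Byte], Array[Byte]) = {
    val keySize = in.readInt()
    val valuesSize = in.readInt()
    val key = new Array[Byte](keySize)
    in.readFully(key)
    val values = new Array[Byte](valuesSize)
    in.readFully(values)
    (key, values)
  }

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    writeEntry(out, "k1".getBytes("UTF-8"), "v1v2".getBytes("UTF-8"))
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val (k, v) = readEntry(in)
    println(new String(k, "UTF-8") + " -> " + new String(v, "UTF-8"))  // k1 -> v1v2
  }
}
```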
+ override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = left.output @@ -52,13 +58,21 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numRightRows += 1 + row.copy() + }.collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { + numLeftRows += 1 var i = 0 var matched = false @@ -69,6 +83,9 @@ case class LeftSemiJoinBNL( } i += 1 } + if (matched) { + numOutputRows += 1 + } matched }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9..18808adaac63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,19 +38,28 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet) + val hashSet = buildKeyHashSet(buildIter, numRightRows) + hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation) + val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) + hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) } } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c923dc837c44..fc8c9439a6f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -38,7 +39,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) @@ -47,9 +51,15 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val (numBuildRows, numStreamedRows) = buildSide match { + case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) - hashJoin(streamIter, hashed) + val hashed = HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) + hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 6a8c35efca8f..ed282f98b7d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -41,6 +42,11 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -53,39 +59,48 @@ case class ShuffledHashOuterJoin( } protected override def doExecute(): RDD[InternalRow] = { 
+ val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) + val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection leftIter.flatMap( currentRow => { + numLeftRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) }) case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) + val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection rightIter.flatMap ( currentRow => { + numRightRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) }) case FullOuter => // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val leftHashTable = + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + val rightHashTable = + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) + joinedRow, + numOutputRows) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4ae23c186cf7..6b7322671d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -38,7 +38,10 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows")) override def output: Seq[Attribute] = left.output ++ right.output @@ -56,117 +59,276 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + new RowIterator { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - // Mutable per row objects. + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + numLeftRows, + RowIterator.fromScala(rightIter), + numRightRows + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) } else { - // no more result - throw new NoSuchElementException + identity[InternalRow] } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = 
smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + } } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + numOutputRows += 1 + true } else { - rightElement = null + false } } - private def initialize() = { - fetchLeft() - fetchRight() + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + numStreamedRows: LongSQLMetric, + bufferedIter: RowIterator, + numBufferedRows: LongSQLMetric) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + + // Initialization (note: do _not_ want to advance streamed here). + advancedBufferedToRowWithNullFreeJoinKey() + + // --- Public methods --------------------------------------------------------------------------- + + def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. 
+ */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join key, so walk through + // the buffered iterator to buffer the rest of the matching rows. + assert(comp == 0) + bufferMatchingRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (!advancedStreamed()) { + // We have consumed the entire streamed iterator, so there can be no more matches.
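A standalone sketch of the scan strategy behind findNextInnerJoinRows above: advance whichever sorted side has the smaller key, and when the keys compare equal, buffer every buffered-side row for that key so consecutive streamed rows with the same key can reuse the buffer. The integer keys and method names below are illustrative only, not the Spark classes:

```scala
import scala.collection.mutable.ArrayBuffer

// Self-contained sort-merge inner join over two key-sorted sequences.
object SortMergeInnerSketch {
  def innerJoin[A](streamed: Seq[(Int, A)], buffered: Seq[(Int, A)]): Seq[(A, A)] = {
    val out = ArrayBuffer.empty[(A, A)]
    var i = 0
    var j = 0
    while (i < streamed.length && j < buffered.length) {
      val comp = streamed(i)._1.compareTo(buffered(j)._1)
      if (comp < 0) i += 1
      else if (comp > 0) j += 1
      else {
        // Buffer all buffered-side rows with this key (analogous to bufferMatchingRows()).
        val key = streamed(i)._1
        val matches = ArrayBuffer.empty[A]
        var k = j
        while (k < buffered.length && buffered(k)._1 == key) {
          matches += buffered(k)._2
          k += 1
        }
        // Emit the cross product for every streamed row that shares this key.
        while (i < streamed.length && streamed(i)._1 == key) {
          matches.foreach(m => out += ((streamed(i)._2, m)))
          i += 1
        }
        j = k
      }
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(1 -> "l1", 2 -> "l2", 2 -> "l2b", 4 -> "l4")
    val right = Seq(2 -> "r2", 2 -> "r2b", 3 -> "r3", 4 -> "r4")
    innerJoin(left, right).foreach(println)
    // (l2,r2), (l2,r2b), (l2b,r2), (l2b,r2b), (l4,r4)
  }
}
```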
+ matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + bufferedMatches.clear() + if (bufferedRow != null && !streamedRowKey.anyNull) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input then we always return true + true } } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advancedStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + numStreamedRows += 1 + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. + */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + numBufferedRows += 1 + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + bufferedMatches.clear() + do { + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 000000000000..dea9e5e580a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} + +/** + * :: DeveloperApi :: + * Performs a sort merge outer join of two child relations. + * + * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. + */ +@DeveloperApi +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } + + override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
+ keys.map(SortOrder(_, Ascending)) + } + + private def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + private def createLeftKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + private def createRightKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + + override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // An ordering that can be used to compare keys from both sides. + val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } + val resultProj: InternalRow => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + joinType match { + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + numLeftRows, + bufferedIter = RowIterator.fromScala(rightIter), + numRightRows + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + numRightRows, + bufferedIter = RowIterator.fromScala(leftIter), + numLeftRows + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} + + +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceLeft(): Boolean = { + rightIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withLeft(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching right rows, so return nulls for the right row + joinedRow.withRight(rightNullRow) + } else { + // Find the next row from the right input that satisfied the bound condition + if (!advanceRightUntilBoundConditionSatisfied()) { + joinedRow.withRight(rightNullRow) + } + } + true 
+ } else { + // Left input has been exhausted + false + } + } + + private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceRight(): Boolean = { + leftIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withRight(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching left rows, so return nulls for the left row + joinedRow.withLeft(leftNullRow) + } else { + // Find the next row from the left input that satisfied the bound condition + if (!advanceLeftUntilBoundConditionSatisfied()) { + joinedRow.withLeft(leftNullRow) + } + } + true + } else { + // Right input has been exhausted + false + } + } + + private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) + leftIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala new file mode 100644 index 000000000000..a485a1a1d7ae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -0,0 +1,47 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate + + +case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { + + private[this] var predicate: (InternalRow) => Boolean = _ + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = { + child.open() + predicate = GeneratePredicate.generate(condition, child.output) + } + + override def next(): Boolean = { + var found = false + while (!found && child.next()) { + found = predicate.apply(child.get()) + } + found + } + + override def get(): InternalRow = child.get() + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala new file mode 100644 index 000000000000..341c81438e6d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.StructType + +/** + * A local physical operator, in the form of an iterator. + * + * Before consuming the iterator, the `open()` function must be called. + * After consuming the iterator, the `close()` function must be called. + */ +abstract class LocalNode extends TreeNode[LocalNode] { + + def output: Seq[Attribute] + + /** + * Initializes the iterator state. Must be called before calling `next()`. + * + * Implementations of this must also call the `open()` function of its children. + */ + def open(): Unit + + /** + * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. + */ + def next(): Boolean + + /** + * Returns the current tuple. + */ + def get(): InternalRow + + /** + * Closes the iterator and releases all resources. + * + * Implementations of this must also call the `close()` function of its children. + */ + def close(): Unit + + /** + * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
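LocalNode above defines a pull-style protocol: open(), then repeated next()/get(), then close(). The following self-contained toy mirror of that protocol (my own types, not the Spark classes) shows how a scan and a filter compose, and why the filter must test its flag before advancing the child so a matching row is not skipped:

```scala
// Minimal standalone mirror of the open()/next()/get()/close() pull protocol.
trait ToyNode {
  def open(): Unit
  def next(): Boolean
  def get(): Int
  def close(): Unit
}

class ToySeqScan(data: Seq[Int]) extends ToyNode {
  private var it: Iterator[Int] = _
  private var current: Int = _
  def open(): Unit = { it = data.iterator }
  def next(): Boolean = { if (it.hasNext) { current = it.next(); true } else false }
  def get(): Int = current
  def close(): Unit = ()
}

class ToyFilter(pred: Int => Boolean, child: ToyNode) extends ToyNode {
  def open(): Unit = child.open()
  def next(): Boolean = {
    var found = false
    // Test `found` before advancing so the matching row stays current for get().
    while (!found && child.next()) { found = pred(child.get()) }
    found
  }
  def get(): Int = child.get()
  def close(): Unit = child.close()
}

object ToyPipeline {
  def main(args: Array[String]): Unit = {
    val node = new ToyFilter(_ % 2 == 0, new ToySeqScan(Seq(1, 2, 3, 4, 5)))
    node.open()
    while (node.next()) println(node.get())  // prints 2, then 4
    node.close()
  }
}
```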
+ */ + def collect(): Seq[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) + val result = new scala.collection.mutable.ArrayBuffer[Row] + open() + while (next()) { + result += converter.apply(get()).asInstanceOf[Row] + } + close() + result + } +} + + +abstract class LeafLocalNode extends LocalNode { + override def children: Seq[LocalNode] = Seq.empty +} + + +abstract class UnaryLocalNode extends LocalNode { + + def child: LocalNode + + override def children: Seq[LocalNode] = Seq(child) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala new file mode 100644 index 000000000000..e574d1473cdc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -0,0 +1,42 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} + + +case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { + + private[this] var project: UnsafeProjection = _ + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override def open(): Unit = { + project = UnsafeProjection.create(projectList, child.output) + child.open() + } + + override def next(): Boolean = child.next() + + override def get(): InternalRow = { + project.apply(child.get()) + } + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala new file mode 100644 index 000000000000..994de8afa9a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * An operator that scans some local data collection in the form of Scala Seq. + */ +case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { + + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + iterator = data.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + // Do nothing + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 3b907e5da789..7a2a98ec18cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} @@ -93,22 +93,6 @@ private[sql] class LongSQLMetric private[metric](name: String) } } -/** - * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's - * `+=` and `add`. - */ -private[sql] class IntSQLMetric private[metric](name: String) - extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { - - override def +=(term: Int): Unit = { - localValue.add(term) - } - - override def add(term: Int): Unit = { - localValue.add(term) - } -} - private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) @@ -121,29 +105,17 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) } -private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { - - override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) - - override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = - r1.add(r2.value) - - override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero - - override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) -} - private[sql] object SQLMetrics { - def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { - val acc = new IntSQLMetric(name) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { val acc = new LongSQLMetric(name) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } + + /** + * A metric that its value will be ignored. Use this one when we need a metric parameter but don't + * care about the value. 
+ */ + val nullLongMetric = new LongSQLMetric("null") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8f1314..28fa231e722d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,6 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. */ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index e31693047012..40ef7c3b5353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -122,7 +122,6 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output - val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes /** * Set up the sorter in each partition before computing the parent partition. @@ -143,6 +142,7 @@ case class TungstenSort( } } + val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index cb7ca60b2fe4..49646a99d68c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -15,7 +15,7 @@ * limitations under the License. 
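In the sort.scala hunk above, the page-size lookup is moved from the top of `doExecute()` (driver side) into the per-partition preparation code, so `SparkEnv.get` is evaluated on the executor that actually builds the sorter. Below is a hedged sketch of the same pattern with a plain RDD; `sortedWithinPartitions`, the element type, and the elided sorter setup are placeholders, and only `SparkEnv.get.shuffleMemoryManager.pageSizeBytes` comes from the patch.

```scala
import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD

// Sketch: read executor-local settings inside the task body, not on the driver,
// so each task sees the SparkEnv of the executor it runs on.
def sortedWithinPartitions(rdd: RDD[Long]): RDD[Long] = {
  rdd.mapPartitions { iter =>
    // Evaluated once per task, on the executor.
    val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes
    // ... a real implementation would size its sorter with `pageSize`;
    // the actual sorting is elided in this sketch ...
    iter
  }
}
```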
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 52ddf99e9266..f0b56c2eb7a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 2fd4fc658d06..5779c71f64e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import scala.collection.mutable @@ -26,7 +26,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -51,17 +51,14 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - @VisibleForTesting def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { _executionIdToData.toMap } - @VisibleForTesting def jobIdToExecutionId: Map[Long, Long] = synchronized { _jobIdToExecutionId.toMap } - @VisibleForTesting def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { _stageIdToStageMetrics.toMap } @@ -165,7 +162,7 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. } else if (stageAttemptID > stageMetrics.stageAttemptId) { logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + - s"what we have seen (${stageMetrics.stageAttemptId}})") + s"what we have seen (${stageMetrics.stageAttemptId})") } else { // TODO We don't know the attemptId. Currently, what we can do is overriding the // accumulator updates. 
However, if there are two same task are running, such as diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 3bba0afaf14e..0b0867f67eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger @@ -38,12 +38,12 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" private val nextTabId = new AtomicInteger(0) private def nextTabName: String = { val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL${nextId}" + if (nextId == 0) "SQL" else s"SQL$nextId" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1ba50b95becc..ae3d752dde34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 5180871585f2..258afadc7695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} @@ -26,7 +25,7 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * The abstract class for implementing user-defined aggregate functions. + * The base class for implementing user-defined aggregate functions (UDAF). */ @Experimental abstract class UserDefinedAggregateFunction extends Serializable { @@ -67,22 +66,35 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. */ - def returnDataType: DataType + def dataType: DataType - /** Indicates if this function is deterministic. 
*/ + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ def deterministic: Boolean /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer - * still store initial values. + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. */ def initialize(buffer: MutableAggregationBuffer): Unit - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ def update(buffer: MutableAggregationBuffer, input: Row): Unit - /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** @@ -92,7 +104,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { def evaluate(buffer: Row): Any /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ @scala.annotation.varargs def apply(exprs: Column*): Column = { @@ -105,16 +117,16 @@ abstract class UserDefinedAggregateFunction extends Serializable { } /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. - * If `isDistinct` is true, this UDAF is working on distinct input values. + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. */ @scala.annotation.varargs - def apply(isDistinct: Boolean, exprs: Column*): Column = { + def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression2( ScalaUDAF(exprs.map(_.expr), this), Complete, - isDistinct = isDistinct) + isDistinct = true) Column(aggregateExpression) } } @@ -122,9 +134,11 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * :: Experimental :: * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. */ @Experimental -trait MutableAggregationBuffer extends Row { +abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ def update(i: Int, value: Any): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 79c5f596661d..435e6319a64c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1863,14 +1863,15 @@ object functions { def substring_index(str: Column, delim: String, count: Int): Column = SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) - /* Translate any character in the src by a character in replaceString. - * The characters in replaceString is corresponding to the characters in matchingString. 
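The reworked UserDefinedAggregateFunction contract above (initialize as the zero value, update once per input row, merge for combining partial aggregates, evaluate for the final result, and apply/distinct to build a Column) is easiest to see on a concrete aggregate. Below is a hedged sketch of a trivial long-sum UDAF; the `inputSchema` and `bufferSchema` members are assumed to belong to the same class even though they are not shown in this hunk, and the DataFrame in the usage comments is hypothetical.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hedged sketch of a UDAF summing a single LongType column.
class LongSum extends UserDefinedAggregateFunction {
  // inputSchema/bufferSchema are assumed parts of the API not shown in this hunk.
  def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)

  def dataType: DataType = LongType
  def deterministic: Boolean = true

  // Zero value: merge(initialBuffer, initialBuffer) should equal initialBuffer.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0L)

  // Called once per input row.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer.update(0, buffer.getLong(0) + input.getLong(0))
  }

  // Called when combining two partially aggregated buffers.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))

  def evaluate(buffer: Row): Any = buffer.getLong(0)
}

// Usage sketch: `df` is a hypothetical DataFrame with a numeric column "x".
// val longSum = new LongSum
// df.groupBy().agg(longSum(df("x")))          // aggregate over all values
// df.groupBy().agg(longSum.distinct(df("x"))) // aggregate over distinct values only
```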
- * The translate will happen when any character in the string matching with the character - * in the matchingString. - * - * @group string_funcs - * @since 1.5.0 - */ + /** + * Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. + * + * @group string_funcs + * @since 1.5.0 + */ def translate(src: Column, matchingString: String, replaceString: String): Column = StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e0510080f..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
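For the re-wrapped `translate()` scaladoc above, a short hedged usage sketch: `df` and the "word" column are assumptions, and only the `translate(src, matchingString, replaceString)` signature comes from the patch.

```scala
import org.apache.spark.sql.functions.{col, translate}

// Maps 'a' -> '1', 'b' -> '2', 'c' -> '3', character by character, in the "word" column.
// `df` is a hypothetical DataFrame with a string column named "word".
val translated = df.select(translate(col("word"), "abc", "123"))
```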
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. 
- val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. 
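The removed savePartition above captures a simple per-partition write protocol: one connection, autocommit off, a positional INSERT prepared once, one executeUpdate per row, then a single commit, with a rollback if the partition fails. Below is a stripped-down standalone sketch of that transaction pattern over plain JDBC; the `people` table, its two columns, and `getConnection` are placeholders, not part of the patch.

```scala
import java.sql.Connection

// Sketch: write one partition's rows in a single transaction, as savePartition did.
def writePartition(getConnection: () => Connection, rows: Iterator[(Int, String)]): Unit = {
  val conn = getConnection()
  var committed = false
  try {
    conn.setAutoCommit(false) // everything below runs in one database transaction
    val stmt = conn.prepareStatement("INSERT INTO people VALUES (?, ?)")
    try {
      rows.foreach { case (id, name) =>
        stmt.setInt(1, id)
        stmt.setString(2, name)
        stmt.executeUpdate()
      }
    } finally {
      stmt.close()
    }
    conn.commit()
    committed = true
  } finally {
    if (!committed) {
      // We got here through an exception: undo the partial write, then close.
      try conn.rollback() finally conn.close()
    } else {
      conn.close()
    }
  }
}
```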
- */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4d942e4f9287..3780cbbcc963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -36,6 +36,15 @@ abstract class Filter */ case class EqualTo(attribute: String, value: Any) extends Filter +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6bcabbab4f77..b3b326fe612c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ @@ -43,19 +43,24 @@ import org.apache.spark.util.SerializableConfiguration * This allows users to give the data source alias as the format type over the fully qualified * class name. * - * ex: parquet.DefaultSource.format = "parquet". - * * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.5.0 */ @DeveloperApi trait DataSourceRegister { /** * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source, - * ex: override def format(): String = "parquet" + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 */ - def format(): String + def shortName(): String } /** @@ -401,7 +406,7 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with Logging { + extends BaseRelation with FileRelation with Logging { override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") @@ -511,6 +516,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def paths: Array[String] + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md index 3dd9861b4896..421c2ea4f7ae 100644 --- a/sql/core/src/test/README.md +++ b/sql/core/src/test/README.md @@ -6,23 +6,19 @@ The following directories and files are used for Parquet compatibility tests: . ├── README.md # This file ├── avro -│   ├── parquet-compat.avdl # Testing Avro IDL -│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ├── gen-java # !! NO TOUCH !! Generated Java code ├── scripts -│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift └── thrift - └── parquet-compat.thrift # Testing Thrift schema + └── *.thrift # Testing Thrift schema(s) ``` -Generated Java code are used in the following test suites: - -- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` -- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` - To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. -When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. ## Prerequisites diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 24729f6143e6..c5eb5b5164cf 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -16,14 +16,25 @@ */ // This is a test protocol for testing parquet-avro compatibility. 
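The DataSourceRegister change in interfaces.scala above renames the alias method to `shortName()`. Below is a hedged sketch of a provider exposing a short alias; the `DefaultSource` class, the "myformat" alias, and the `RelationProvider` parent are illustrative assumptions, and only `shortName()` itself comes from the patch.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// Hypothetical provider: lets callers write .format("myformat") instead of the full class name.
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "myformat"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Building a real relation is out of scope for this sketch.
    throw new UnsupportedOperationException("sketch only")
  }
}
```

Assuming the provider is also listed in a `META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` file, `sqlContext.read.format("myformat")` would be expected to resolve to it by its short name.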
-@namespace("org.apache.spark.sql.parquet.test.avro") +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + record Nested { array nested_ints_column; string nested_string_column; } - record ParquetAvroCompat { + record AvroPrimitives { boolean bool_column; int int_column; long long_column; @@ -31,7 +42,9 @@ protocol CompatibilityTest { double double_column; bytes binary_column; string string_column; + } + record AvroOptionalPrimitives { union { null, boolean } maybe_bool_column; union { null, int } maybe_int_column; union { null, long } maybe_long_column; @@ -39,7 +52,22 @@ protocol CompatibilityTest { union { null, double } maybe_double_column; union { null, bytes } maybe_binary_column; union { null, string } maybe_string_column; + } + record AvroNonNullableArrays { + array strings_column; + union { null, array } maybe_ints_column; + } + + record AvroArrayOfArray { + array> int_arrays_column; + } + + record AvroMapOfArray { + map> string_to_ints_column; + } + + record ParquetAvroCompat { array strings_column; map string_to_int_column; map> complex_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index a83b7c990dd2..9ad315b74fb4 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -1,7 +1,18 @@ { "protocol" : "CompatibilityTest", - "namespace" : "org.apache.spark.sql.parquet.test.avro", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { "type" : "record", "name" : "Nested", "fields" : [ { @@ -16,7 +27,7 @@ } ] }, { "type" : "record", - "name" : "ParquetAvroCompat", + "name" : "AvroPrimitives", "fields" : [ { "name" : "bool_column", "type" : "boolean" @@ -38,7 +49,11 @@ }, { "name" : "string_column", "type" : "string" - }, { + } ] + }, { + "type" : "record", + "name" : "AvroOptionalPrimitives", + "fields" : [ { "name" : "maybe_bool_column", "type" : [ "null", "boolean" ] }, { @@ -59,7 +74,53 @@ }, { "name" : "maybe_string_column", "type" : [ "null", "string" ] + } ] + }, { + "type" : "record", + "name" : "AvroNonNullableArrays", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } }, { + "name" : "maybe_ints_column", + "type" : [ "null", { + "type" : "array", + "items" : "int" + } ] + } ] + }, { + "type" : "record", + "name" : "AvroArrayOfArray", + "fields" : [ { + "name" : "int_arrays_column", + "type" : { + "type" : "array", + "items" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "AvroMapOfArray", + "fields" : [ { + "name" : "string_to_ints_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { "name" : "strings_column", "type" : { "type" : "array", diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java new file mode 100644 index 000000000000..ee327827903e --- /dev/null 
+++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroArrayOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List> int_arrays_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroArrayOfArray() {} + + /** + * All-args constructor. + */ + public AvroArrayOfArray(java.util.List> int_arrays_column) { + this.int_arrays_column = int_arrays_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return int_arrays_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: int_arrays_column = (java.util.List>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'int_arrays_column' field. + */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** + * Sets the value of the 'int_arrays_column' field. + * @param value the value to set. + */ + public void setIntArraysColumn(java.util.List> value) { + this.int_arrays_column = value; + } + + /** Creates a new AvroArrayOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing AvroArrayOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroArrayOfArray instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List> int_arrays_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroArrayOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'int_arrays_column' field */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** Sets the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder setIntArraysColumn(java.util.List> value) { + validate(fields()[0], value); + this.int_arrays_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'int_arrays_column' field has been set */ + public boolean hasIntArraysColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder clearIntArraysColumn() { + int_arrays_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroArrayOfArray build() { + try { + AvroArrayOfArray record = new AvroArrayOfArray(); + record.int_arrays_column = fieldSetFlags()[0] ? 
this.int_arrays_column : (java.util.List>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java new file mode 100644 index 000000000000..727f6a7bf733 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroMapOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.Map> string_to_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroMapOfArray() {} + + /** + * All-args constructor. + */ + public AvroMapOfArray(java.util.Map> string_to_ints_column) { + this.string_to_ints_column = string_to_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return string_to_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: string_to_ints_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'string_to_ints_column' field. + */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** + * Sets the value of the 'string_to_ints_column' field. + * @param value the value to set. 
+ */ + public void setStringToIntsColumn(java.util.Map> value) { + this.string_to_ints_column = value; + } + + /** Creates a new AvroMapOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing AvroMapOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroMapOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.Map> string_to_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroMapOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'string_to_ints_column' field */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** Sets the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder setStringToIntsColumn(java.util.Map> value) { + validate(fields()[0], value); + this.string_to_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'string_to_ints_column' field has been set */ + public boolean hasStringToIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder clearStringToIntsColumn() { + string_to_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroMapOfArray build() { + try { + AvroMapOfArray record = new AvroMapOfArray(); + record.string_to_ints_column = fieldSetFlags()[0] ? 
this.string_to_ints_column : (java.util.Map>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java new file mode 100644 index 000000000000..934793f42f9c --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroNonNullableArrays extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.List maybe_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroNonNullableArrays() {} + + /** + * All-args constructor. + */ + public AvroNonNullableArrays(java.util.List strings_column, java.util.List maybe_ints_column) { + this.strings_column = strings_column; + this.maybe_ints_column = maybe_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return maybe_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: maybe_ints_column = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'maybe_ints_column' field. + */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** + * Sets the value of the 'maybe_ints_column' field. + * @param value the value to set. 
+ */ + public void setMaybeIntsColumn(java.util.List value) { + this.maybe_ints_column = value; + } + + /** Creates a new AvroNonNullableArrays RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing AvroNonNullableArrays instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** + * RecordBuilder for AvroNonNullableArrays instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.List maybe_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing AvroNonNullableArrays instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' 
field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_ints_column' field */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** Sets the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setMaybeIntsColumn(java.util.List value) { + validate(fields()[1], value); + this.maybe_ints_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_ints_column' field has been set */ + public boolean hasMaybeIntsColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearMaybeIntsColumn() { + maybe_ints_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public AvroNonNullableArrays build() { + try { + AvroNonNullableArrays record = new AvroNonNullableArrays(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.maybe_ints_column = fieldSetFlags()[1] ? this.maybe_ints_column : (java.util.List) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java new file mode 100644 index 000000000000..e4d1ead8dd15 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java @@ -0,0 +1,466 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroOptionalPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String 
maybe_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroOptionalPrimitives() {} + + /** + * All-args constructor. + */ + public AvroOptionalPrimitives(java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column) { + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return maybe_bool_column; + case 1: return maybe_int_column; + case 2: return maybe_long_column; + case 3: return maybe_float_column; + case 4: return maybe_double_column; + case 5: return maybe_binary_column; + case 6: return maybe_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: maybe_bool_column = (java.lang.Boolean)value$; break; + case 1: maybe_int_column = (java.lang.Integer)value$; break; + case 2: maybe_long_column = (java.lang.Long)value$; break; + case 3: maybe_float_column = (java.lang.Float)value$; break; + case 4: maybe_double_column = (java.lang.Double)value$; break; + case 5: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 6: maybe_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. 
+ */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing AvroOptionalPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroOptionalPrimitives instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroOptionalPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } 
+ if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[0], value); + this.maybe_bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[1], value); + this.maybe_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[2], value); + this.maybe_long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[3], value); + this.maybe_float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder 
clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[4], value); + this.maybe_double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.maybe_binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.maybe_string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroOptionalPrimitives build() { + try { + AvroOptionalPrimitives record = new AvroOptionalPrimitives(); + record.maybe_bool_column = fieldSetFlags()[0] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.maybe_int_column = fieldSetFlags()[1] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.maybe_long_column = fieldSetFlags()[2] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[2]); + record.maybe_float_column = fieldSetFlags()[3] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[3]); + record.maybe_double_column = fieldSetFlags()[4] ? 
this.maybe_double_column : (java.lang.Double) defaultValue(fields()[4]); + record.maybe_binary_column = fieldSetFlags()[5] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.maybe_string_column = fieldSetFlags()[6] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java new file mode 100644 index 000000000000..1c2afed16781 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java @@ -0,0 +1,461 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroPrimitives() {} + + /** + * All-args constructor. + */ + public AvroPrimitives(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. 
+ */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** Creates a new AvroPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing AvroPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + 
fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroPrimitives build() { + try { + AvroPrimitives record = new AvroPrimitives(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java new file mode 100644 index 000000000000..28fdc1dfb911 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -0,0 +1,17 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; + +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public interface CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]},{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]},{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]},{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"Parque
tAvroCompat\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + + @SuppressWarnings("all") + public interface Callback extends CompatibilityTest { + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; + } +} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java similarity index 75% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index 051f1ee90386..a7bf4841919c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public java.util.List nested_ints_column; @Deprecated public java.lang.String nested_string_column; @@ -77,18 +77,18 @@ public void setNestedStringColumn(java.lang.String value) { } /** Creates a new Nested RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); } /** Creates a new Nested RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** @@ -102,11 +102,11 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { super(other); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); @@ -119,8 +119,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { } /** Creates a Builder by copying an existing Nested instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); fieldSetFlags()[0] = true; @@ -137,7 +137,7 @@ public java.util.List getNestedIntsColumn() { } /** Sets the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { validate(fields()[0], value); this.nested_ints_column = value; fieldSetFlags()[0] = true; @@ -150,7 +150,7 @@ public boolean hasNestedIntsColumn() { } /** Clears the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { nested_ints_column = null; fieldSetFlags()[0] = false; return this; @@ -162,7 +162,7 @@ public java.lang.String getNestedStringColumn() { } /** Sets the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { validate(fields()[1], value); this.nested_string_column = value; fieldSetFlags()[1] = true; @@ -175,7 +175,7 @@ public boolean 
hasNestedStringColumn() { } /** Clears the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { nested_string_column = null; fieldSetFlags()[1] = false; return this; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java new file mode 100644 index 000000000000..ef12d193f916 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -0,0 +1,250 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.Map string_to_int_column; + @Deprecated public java.util.Map> complex_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetAvroCompat() {} + + /** + * All-args constructor. + */ + public ParquetAvroCompat(java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + this.strings_column = strings_column; + this.string_to_int_column = string_to_int_column; + this.complex_column = complex_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return string_to_int_column; + case 2: return complex_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: string_to_int_column = (java.util.Map)value$; break; + case 2: complex_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'string_to_int_column' field. + */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** + * Sets the value of the 'string_to_int_column' field. + * @param value the value to set. + */ + public void setStringToIntColumn(java.util.Map value) { + this.string_to_int_column = value; + } + + /** + * Gets the value of the 'complex_column' field. + */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** + * Sets the value of the 'complex_column' field. + * @param value the value to set. + */ + public void setComplexColumn(java.util.Map> value) { + this.complex_column = value; + } + + /** Creates a new ParquetAvroCompat RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); + } + + /** + * RecordBuilder for ParquetAvroCompat instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.Map string_to_int_column; + private java.util.Map> complex_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Creates a Builder by copying an existing ParquetAvroCompat instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); + fieldSetFlags()[2] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'string_to_int_column' field */ + public java.util.Map getStringToIntColumn() { + return string_to_int_column; + } + + /** Sets the value of the 'string_to_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + validate(fields()[1], value); + this.string_to_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'string_to_int_column' field has been set */ + public boolean hasStringToIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'string_to_int_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + string_to_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'complex_column' field */ + public java.util.Map> getComplexColumn() { + return complex_column; + } + + /** Sets the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + validate(fields()[2], value); + this.complex_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'complex_column' field has been set */ + public boolean hasComplexColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'complex_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + complex_column = null; + fieldSetFlags()[2] = false; + return this; + } + + @Override + public ParquetAvroCompat build() { + try { + ParquetAvroCompat record = new ParquetAvroCompat(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.string_to_int_column = fieldSetFlags()[1] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[1]); + record.complex_column = fieldSetFlags()[2] ? this.complex_column : (java.util.Map>) defaultValue(fields()[2]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 000000000000..05fefe4cee75 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. 
+ public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. + */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? 
this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 000000000000..00711a0c2a26 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java deleted file mode 100644 index daec65a5bbe5..000000000000 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Autogenerated by Avro - * - * DO NOT EDIT DIRECTLY - */ -package org.apache.spark.sql.parquet.test.avro; - -@SuppressWarnings("all") -@org.apache.avro.specific.AvroGenerated -public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = 
org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); - - @SuppressWarnings("all") - public interface Callback extends CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; - } -} \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java deleted file mode 100644 index 354c9d73cca3..000000000000 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java +++ /dev/null @@ -1,1001 +0,0 @@ -/** - * Autogenerated by Avro - * - * DO NOT EDIT DIRECTLY - */ -package org.apache.spark.sql.parquet.test.avro; -@SuppressWarnings("all") -@org.apache.avro.specific.AvroGenerated -public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); - public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } - @Deprecated public boolean bool_column; - @Deprecated public int int_column; - @Deprecated public long long_column; - @Deprecated public float float_column; - @Deprecated public double double_column; - @Deprecated public java.nio.ByteBuffer binary_column; - @Deprecated public java.lang.String string_column; - @Deprecated public java.lang.Boolean maybe_bool_column; - @Deprecated public java.lang.Integer maybe_int_column; - @Deprecated public java.lang.Long maybe_long_column; - @Deprecated public java.lang.Float maybe_float_column; - @Deprecated public java.lang.Double maybe_double_column; - @Deprecated public java.nio.ByteBuffer maybe_binary_column; - @Deprecated public java.lang.String maybe_string_column; - @Deprecated public java.util.List strings_column; - @Deprecated public java.util.Map string_to_int_column; - @Deprecated public java.util.Map> complex_column; - - /** - * Default constructor. Note that this does not initialize fields - * to their default values from the schema. If that is desired then - * one should use newBuilder(). - */ - public ParquetAvroCompat() {} - - /** - * All-args constructor. 
- */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { - this.bool_column = bool_column; - this.int_column = int_column; - this.long_column = long_column; - this.float_column = float_column; - this.double_column = double_column; - this.binary_column = binary_column; - this.string_column = string_column; - this.maybe_bool_column = maybe_bool_column; - this.maybe_int_column = maybe_int_column; - this.maybe_long_column = maybe_long_column; - this.maybe_float_column = maybe_float_column; - this.maybe_double_column = maybe_double_column; - this.maybe_binary_column = maybe_binary_column; - this.maybe_string_column = maybe_string_column; - this.strings_column = strings_column; - this.string_to_int_column = string_to_int_column; - this.complex_column = complex_column; - } - - public org.apache.avro.Schema getSchema() { return SCHEMA$; } - // Used by DatumWriter. Applications should not call. - public java.lang.Object get(int field$) { - switch (field$) { - case 0: return bool_column; - case 1: return int_column; - case 2: return long_column; - case 3: return float_column; - case 4: return double_column; - case 5: return binary_column; - case 6: return string_column; - case 7: return maybe_bool_column; - case 8: return maybe_int_column; - case 9: return maybe_long_column; - case 10: return maybe_float_column; - case 11: return maybe_double_column; - case 12: return maybe_binary_column; - case 13: return maybe_string_column; - case 14: return strings_column; - case 15: return string_to_int_column; - case 16: return complex_column; - default: throw new org.apache.avro.AvroRuntimeException("Bad index"); - } - } - // Used by DatumReader. Applications should not call. - @SuppressWarnings(value="unchecked") - public void put(int field$, java.lang.Object value$) { - switch (field$) { - case 0: bool_column = (java.lang.Boolean)value$; break; - case 1: int_column = (java.lang.Integer)value$; break; - case 2: long_column = (java.lang.Long)value$; break; - case 3: float_column = (java.lang.Float)value$; break; - case 4: double_column = (java.lang.Double)value$; break; - case 5: binary_column = (java.nio.ByteBuffer)value$; break; - case 6: string_column = (java.lang.String)value$; break; - case 7: maybe_bool_column = (java.lang.Boolean)value$; break; - case 8: maybe_int_column = (java.lang.Integer)value$; break; - case 9: maybe_long_column = (java.lang.Long)value$; break; - case 10: maybe_float_column = (java.lang.Float)value$; break; - case 11: maybe_double_column = (java.lang.Double)value$; break; - case 12: maybe_binary_column = (java.nio.ByteBuffer)value$; break; - case 13: maybe_string_column = (java.lang.String)value$; break; - case 14: strings_column = (java.util.List)value$; break; - case 15: string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; - default: throw new org.apache.avro.AvroRuntimeException("Bad index"); - } - } - - /** - * Gets the value of the 'bool_column' field. 
- */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** - * Sets the value of the 'bool_column' field. - * @param value the value to set. - */ - public void setBoolColumn(java.lang.Boolean value) { - this.bool_column = value; - } - - /** - * Gets the value of the 'int_column' field. - */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** - * Sets the value of the 'int_column' field. - * @param value the value to set. - */ - public void setIntColumn(java.lang.Integer value) { - this.int_column = value; - } - - /** - * Gets the value of the 'long_column' field. - */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** - * Sets the value of the 'long_column' field. - * @param value the value to set. - */ - public void setLongColumn(java.lang.Long value) { - this.long_column = value; - } - - /** - * Gets the value of the 'float_column' field. - */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** - * Sets the value of the 'float_column' field. - * @param value the value to set. - */ - public void setFloatColumn(java.lang.Float value) { - this.float_column = value; - } - - /** - * Gets the value of the 'double_column' field. - */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** - * Sets the value of the 'double_column' field. - * @param value the value to set. - */ - public void setDoubleColumn(java.lang.Double value) { - this.double_column = value; - } - - /** - * Gets the value of the 'binary_column' field. - */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** - * Sets the value of the 'binary_column' field. - * @param value the value to set. - */ - public void setBinaryColumn(java.nio.ByteBuffer value) { - this.binary_column = value; - } - - /** - * Gets the value of the 'string_column' field. - */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** - * Sets the value of the 'string_column' field. - * @param value the value to set. - */ - public void setStringColumn(java.lang.String value) { - this.string_column = value; - } - - /** - * Gets the value of the 'maybe_bool_column' field. - */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** - * Sets the value of the 'maybe_bool_column' field. - * @param value the value to set. - */ - public void setMaybeBoolColumn(java.lang.Boolean value) { - this.maybe_bool_column = value; - } - - /** - * Gets the value of the 'maybe_int_column' field. - */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** - * Sets the value of the 'maybe_int_column' field. - * @param value the value to set. - */ - public void setMaybeIntColumn(java.lang.Integer value) { - this.maybe_int_column = value; - } - - /** - * Gets the value of the 'maybe_long_column' field. - */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** - * Sets the value of the 'maybe_long_column' field. - * @param value the value to set. - */ - public void setMaybeLongColumn(java.lang.Long value) { - this.maybe_long_column = value; - } - - /** - * Gets the value of the 'maybe_float_column' field. - */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** - * Sets the value of the 'maybe_float_column' field. - * @param value the value to set. 
- */ - public void setMaybeFloatColumn(java.lang.Float value) { - this.maybe_float_column = value; - } - - /** - * Gets the value of the 'maybe_double_column' field. - */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** - * Sets the value of the 'maybe_double_column' field. - * @param value the value to set. - */ - public void setMaybeDoubleColumn(java.lang.Double value) { - this.maybe_double_column = value; - } - - /** - * Gets the value of the 'maybe_binary_column' field. - */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** - * Sets the value of the 'maybe_binary_column' field. - * @param value the value to set. - */ - public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { - this.maybe_binary_column = value; - } - - /** - * Gets the value of the 'maybe_string_column' field. - */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** - * Sets the value of the 'maybe_string_column' field. - * @param value the value to set. - */ - public void setMaybeStringColumn(java.lang.String value) { - this.maybe_string_column = value; - } - - /** - * Gets the value of the 'strings_column' field. - */ - public java.util.List getStringsColumn() { - return strings_column; - } - - /** - * Sets the value of the 'strings_column' field. - * @param value the value to set. - */ - public void setStringsColumn(java.util.List value) { - this.strings_column = value; - } - - /** - * Gets the value of the 'string_to_int_column' field. - */ - public java.util.Map getStringToIntColumn() { - return string_to_int_column; - } - - /** - * Sets the value of the 'string_to_int_column' field. - * @param value the value to set. - */ - public void setStringToIntColumn(java.util.Map value) { - this.string_to_int_column = value; - } - - /** - * Gets the value of the 'complex_column' field. - */ - public java.util.Map> getComplexColumn() { - return complex_column; - } - - /** - * Sets the value of the 'complex_column' field. - * @param value the value to set. - */ - public void setComplexColumn(java.util.Map> value) { - this.complex_column = value; - } - - /** Creates a new ParquetAvroCompat RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); - } - - /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); - } - - /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); - } - - /** - * RecordBuilder for ParquetAvroCompat instances. 
- */ - public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase - implements org.apache.avro.data.RecordBuilder { - - private boolean bool_column; - private int int_column; - private long long_column; - private float float_column; - private double double_column; - private java.nio.ByteBuffer binary_column; - private java.lang.String string_column; - private java.lang.Boolean maybe_bool_column; - private java.lang.Integer maybe_int_column; - private java.lang.Long maybe_long_column; - private java.lang.Float maybe_float_column; - private java.lang.Double maybe_double_column; - private java.nio.ByteBuffer maybe_binary_column; - private java.lang.String maybe_string_column; - private java.util.List strings_column; - private java.util.Map string_to_int_column; - private java.util.Map> complex_column; - - /** Creates a new Builder */ - private Builder() { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); - } - - /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - super(other); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); - fieldSetFlags()[0] = true; - } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); - fieldSetFlags()[1] = true; - } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); - fieldSetFlags()[2] = true; - } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = 
data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } - } - - /** Creates a Builder by copying an existing ParquetAvroCompat instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); - fieldSetFlags()[0] = true; - } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); - fieldSetFlags()[1] = true; - } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); - fieldSetFlags()[2] = true; - } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - 
fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } - } - - /** Gets the value of the 'bool_column' field */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { - validate(fields()[0], value); - this.bool_column = value; - fieldSetFlags()[0] = true; - return this; - } - - /** Checks whether the 'bool_column' field has been set */ - public boolean hasBoolColumn() { - return fieldSetFlags()[0]; - } - - /** Clears the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { - fieldSetFlags()[0] = false; - return this; - } - - /** Gets the value of the 'int_column' field */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { - validate(fields()[1], value); - this.int_column = value; - fieldSetFlags()[1] = true; - return this; - } - - /** Checks whether the 'int_column' field has been set */ - public boolean hasIntColumn() { - return fieldSetFlags()[1]; - } - - /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { - fieldSetFlags()[1] = false; - return this; - } - - /** Gets the value of the 'long_column' field */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { - validate(fields()[2], value); - this.long_column = value; - fieldSetFlags()[2] = true; - return this; - } - - /** Checks whether the 'long_column' field has been set */ - public boolean hasLongColumn() { - return fieldSetFlags()[2]; - } - - /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { - fieldSetFlags()[2] = false; - return this; - } - - /** Gets the value of the 'float_column' field */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { - validate(fields()[3], value); - this.float_column = value; - fieldSetFlags()[3] = true; - return this; - } - - /** Checks whether the 'float_column' field has been set */ - public boolean hasFloatColumn() { - return fieldSetFlags()[3]; - } - - /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { - fieldSetFlags()[3] = false; - return this; - } - - /** Gets the value of the 'double_column' field */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double 
value) { - validate(fields()[4], value); - this.double_column = value; - fieldSetFlags()[4] = true; - return this; - } - - /** Checks whether the 'double_column' field has been set */ - public boolean hasDoubleColumn() { - return fieldSetFlags()[4]; - } - - /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { - fieldSetFlags()[4] = false; - return this; - } - - /** Gets the value of the 'binary_column' field */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { - validate(fields()[5], value); - this.binary_column = value; - fieldSetFlags()[5] = true; - return this; - } - - /** Checks whether the 'binary_column' field has been set */ - public boolean hasBinaryColumn() { - return fieldSetFlags()[5]; - } - - /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { - binary_column = null; - fieldSetFlags()[5] = false; - return this; - } - - /** Gets the value of the 'string_column' field */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { - validate(fields()[6], value); - this.string_column = value; - fieldSetFlags()[6] = true; - return this; - } - - /** Checks whether the 'string_column' field has been set */ - public boolean hasStringColumn() { - return fieldSetFlags()[6]; - } - - /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { - string_column = null; - fieldSetFlags()[6] = false; - return this; - } - - /** Gets the value of the 'maybe_bool_column' field */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { - validate(fields()[7], value); - this.maybe_bool_column = value; - fieldSetFlags()[7] = true; - return this; - } - - /** Checks whether the 'maybe_bool_column' field has been set */ - public boolean hasMaybeBoolColumn() { - return fieldSetFlags()[7]; - } - - /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { - maybe_bool_column = null; - fieldSetFlags()[7] = false; - return this; - } - - /** Gets the value of the 'maybe_int_column' field */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { - validate(fields()[8], value); - this.maybe_int_column = value; - fieldSetFlags()[8] = true; - return this; - } - - /** Checks whether the 'maybe_int_column' field has been set */ - public boolean hasMaybeIntColumn() { - return fieldSetFlags()[8]; - } - - /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { - 
maybe_int_column = null; - fieldSetFlags()[8] = false; - return this; - } - - /** Gets the value of the 'maybe_long_column' field */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { - validate(fields()[9], value); - this.maybe_long_column = value; - fieldSetFlags()[9] = true; - return this; - } - - /** Checks whether the 'maybe_long_column' field has been set */ - public boolean hasMaybeLongColumn() { - return fieldSetFlags()[9]; - } - - /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { - maybe_long_column = null; - fieldSetFlags()[9] = false; - return this; - } - - /** Gets the value of the 'maybe_float_column' field */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { - validate(fields()[10], value); - this.maybe_float_column = value; - fieldSetFlags()[10] = true; - return this; - } - - /** Checks whether the 'maybe_float_column' field has been set */ - public boolean hasMaybeFloatColumn() { - return fieldSetFlags()[10]; - } - - /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { - maybe_float_column = null; - fieldSetFlags()[10] = false; - return this; - } - - /** Gets the value of the 'maybe_double_column' field */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { - validate(fields()[11], value); - this.maybe_double_column = value; - fieldSetFlags()[11] = true; - return this; - } - - /** Checks whether the 'maybe_double_column' field has been set */ - public boolean hasMaybeDoubleColumn() { - return fieldSetFlags()[11]; - } - - /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { - maybe_double_column = null; - fieldSetFlags()[11] = false; - return this; - } - - /** Gets the value of the 'maybe_binary_column' field */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { - validate(fields()[12], value); - this.maybe_binary_column = value; - fieldSetFlags()[12] = true; - return this; - } - - /** Checks whether the 'maybe_binary_column' field has been set */ - public boolean hasMaybeBinaryColumn() { - return fieldSetFlags()[12]; - } - - /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { - maybe_binary_column = null; - fieldSetFlags()[12] = false; - return this; - } - - /** Gets the value of the 'maybe_string_column' field */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** Sets the value 
of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { - validate(fields()[13], value); - this.maybe_string_column = value; - fieldSetFlags()[13] = true; - return this; - } - - /** Checks whether the 'maybe_string_column' field has been set */ - public boolean hasMaybeStringColumn() { - return fieldSetFlags()[13]; - } - - /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { - maybe_string_column = null; - fieldSetFlags()[13] = false; - return this; - } - - /** Gets the value of the 'strings_column' field */ - public java.util.List getStringsColumn() { - return strings_column; - } - - /** Sets the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { - validate(fields()[14], value); - this.strings_column = value; - fieldSetFlags()[14] = true; - return this; - } - - /** Checks whether the 'strings_column' field has been set */ - public boolean hasStringsColumn() { - return fieldSetFlags()[14]; - } - - /** Clears the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { - strings_column = null; - fieldSetFlags()[14] = false; - return this; - } - - /** Gets the value of the 'string_to_int_column' field */ - public java.util.Map getStringToIntColumn() { - return string_to_int_column; - } - - /** Sets the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { - validate(fields()[15], value); - this.string_to_int_column = value; - fieldSetFlags()[15] = true; - return this; - } - - /** Checks whether the 'string_to_int_column' field has been set */ - public boolean hasStringToIntColumn() { - return fieldSetFlags()[15]; - } - - /** Clears the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { - string_to_int_column = null; - fieldSetFlags()[15] = false; - return this; - } - - /** Gets the value of the 'complex_column' field */ - public java.util.Map> getComplexColumn() { - return complex_column; - } - - /** Sets the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { - validate(fields()[16], value); - this.complex_column = value; - fieldSetFlags()[16] = true; - return this; - } - - /** Checks whether the 'complex_column' field has been set */ - public boolean hasComplexColumn() { - return fieldSetFlags()[16]; - } - - /** Clears the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { - complex_column = null; - fieldSetFlags()[16] = false; - return this; - } - - @Override - public ParquetAvroCompat build() { - try { - ParquetAvroCompat record = new ParquetAvroCompat(); - record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); - record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); - record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); - record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); - record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); - record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); - record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); - record.maybe_bool_column = fieldSetFlags()[7] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[7]); - record.maybe_int_column = fieldSetFlags()[8] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[8]); - record.maybe_long_column = fieldSetFlags()[9] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[9]); - record.maybe_float_column = fieldSetFlags()[10] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[10]); - record.maybe_double_column = fieldSetFlags()[11] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[11]); - record.maybe_binary_column = fieldSetFlags()[12] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[12]); - record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); - record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); - record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); - return record; - } catch (Exception e) { - throw new org.apache.avro.AvroRuntimeException(e); - } - } - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index e912eb835d16..bf693c7c393f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +35,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7302361ab9fd..7abdd3db8034 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,44 +17,45 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConversions; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test @@ -230,7 +231,7 @@ public void testCovariance() { @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff37..bb02b58cca9b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -23,12 +23,12 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28a..6f9e7f68dc39 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -52,8 +53,9 @@ private void checkAnswer(DataFrame actual, List<Row> expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -71,6 +73,13 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { Map<String, String> options = new HashMap<String, String>(); diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/sql/core/src/test/resources/nested-array-struct.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 000000000000..520922f73ebb Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-int.parquet differ diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/old-repeated-message.parquet new file mode 100644 index 000000000000..548db9916277 Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-message.parquet differ diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 000000000000..213f1a90291b Binary files /dev/null and b/sql/core/src/test/resources/old-repeated.parquet differ diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet old mode 100755 new mode 100644 diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-string.parquet differ diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 000000000000..c29eee35c350 Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-struct.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/proto-struct-with-array-many.parquet new file mode 100644 index 000000000000..ff9809675fc0 Binary
files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array-many.parquet differ diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/proto-struct-with-array.parquet new file mode 100644 index 000000000000..325a8370ad20 Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array.parquet differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a88df91b1001..af7590c3d3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,24 +18,20 @@ package org.apache.spark.sql import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6a09a3b72c08..37738ec5b3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,20 +17,25 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -106,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single 
explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( @@ -258,7 +271,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -268,7 +281,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } @@ -289,7 +302,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } @@ -309,7 +322,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -433,13 +446,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) - test("&&") { checkAnswer( booleanData.filter($"a" && true), @@ -523,7 +529,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -544,7 +550,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff7440a76..72cf7aab0b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 000000000000..3c359dd840ab --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). + */ +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03116a374f3b..9d965258e389 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a494..e5d7d63441a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c706242d..e2716d7841d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44ee2c7..cdaa14ac8078 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f5984e4a8ce..28bdd6f83b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString test("sample with replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest { test("sample without replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest { } test("Frequent Items 2") { - val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. 
This test should also fail // if anything like SPARK-9614 is observed once again @@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest { } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c49f256be550..284fff184085 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,18 +23,12 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. @@ -485,21 +479,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("inputFiles") { - val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), - Some(testData.schema), None, Map.empty)(sqlContext) - val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) - assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") - val fakeRelation2 = new JSONRelation( - None, 1, Some(testData.schema), None, None, Array("/json/path"))(sqlContext) - val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) - assert(df2.inputFiles.toSet == fakeRelation2.paths.toSet) + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) - val unionDF = df1.unionAll(df2) - assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(jsonDF.inputFiles.nonEmpty) - val filtered = df1.filter("false").unionAll(df2.intersect(df2)) - assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + assert(unioned === allFiles) + } } ignore("show") { @@ -871,4 +867,24 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) assert(expected === actual) } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + }
+ + test("SPARK-9950: correctly analyze grouping/aggregating on struct fields") { + val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") + checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) + } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a97bc6..77907e91363e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,10 +27,8 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. */ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897caf952a..9080c53c491a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,19 +22,18 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } @@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( 
0).getTime - System.currentTimeMillis()) < 5000) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d896603..f5c5046a8ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,32 +17,26 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. - TestData +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -55,6 +49,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -66,7 +61,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -83,11 +77,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", @@ -97,90 +91,83 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( - ("SELECT * FROM testData JOIN testData2 ON 
key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + sql("CACHE TABLE testData") + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + } } - - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * 
FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + sql("CACHE TABLE testData") + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - - ctx.sql("UNCACHE TABLE testData") + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -285,7 +272,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -299,7 +286,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -345,7 +332,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -354,7 +341,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -406,7 +393,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -415,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -430,7 +417,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -445,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -456,31 +443,30 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold + sql("CACHE TABLE testData") - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: @@ -488,6 +474,5 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 71c26a6f8d36..045fea82e4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class JsonFunctionsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function get_json_object") { val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf..babf8835d254 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5957d8..30289c3c1d09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - 
ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } @@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..4adcefb7dc4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -71,12 +71,6 @@ class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7865d6..795d4e983f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c2..7699adadd9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(ctx.getConf(testKey, testVal + "_") === testVal) assert(ctx.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") + sql("set some.property=20") assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") + sql("set some.property = 40") assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") + sql(s"set $key=$vs") assert(ctx.getConf(key, "0") === vs) - ctx.sql(s"set $key=") + sql(s"set $key=") assert(ctx.getConf(key, "0") === "") ctx.conf.clear() @@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(ctx.conf.numShufflePartitions === 10) } test("invalid conf value") { ctx.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a4..007be1295077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy 
val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b14ef9bab90c..dcb4e8371098 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,30 +17,26 @@ package org.apache.spark.sql +import java.math.MathContext import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. - TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -60,7 +56,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { @@ -178,7 +175,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) @@ -712,9 +709,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -1161,7 +1156,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + 
validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1613,6 +1609,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + + test("decimal precision with multiply/division") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(null)) + } + test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { val sc = sqlContext.sparkContext @@ -1623,12 +1637,53 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-9511: error with table starting with number") { - val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) - .toDF("num", "str") - df.registerTempTable("1one") + withTempTable("1one") { + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + .registerTempTable("1one") + checkAnswer(sql("select count(num) from 1one"), Row(10)) + } + } - checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // Backticks can be used to quote the name of a temporary table that has a dot in it.
+ sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } - sqlContext.dropTempTable("1one") + test("SPARK-10130 type coercion for IF should have children resolved first") { + val df = Seq((1, 1), (-1, 1)).toDF("key", "value") + df.registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd96d27..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -71,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e460b79..45d0ee4a8e74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ca298b243441..b91438baea06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") @@ -349,9 +348,9 @@ class StringFunctionsSuite extends QueryTest { // it will still use the interpretProjection if projection follows by a LocalRelation, // hence we add a filter operator. // See the optimizer rule `ConvertToLocalRelation` - val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + val df2 = Seq((5L, 4), (4L, 3), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b") checkAnswer( df2.filter("b>0").selectExpr("format_number(a, b)"), - Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) + Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index bd9729c431f3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class 
StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 183dc3407b3a..eb275af101e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame @@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") 
df.registerTempTable("tmp_table") - checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) ctx.dropTempTable("tmp_table") } @@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils { val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") - val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) - assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) ctx.dropTempTable("test_table") } } @@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { @@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("integerData") val result = - ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } @@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT g, SUM(v) as s | FROM groupData @@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT SUM(v) | FROM groupData @@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData @@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } @@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("udf that is transformed") { ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { ctx.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
- assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 89bad1bfdab0..219435dff5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -51,7 +51,7 @@ class UnsafeRowSuite extends SparkFunSuite { val bytesFromOffheapRow: Array[Byte] = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { - PlatformDependent.copyMemory( + Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, offheapRowPage.getBaseObject, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9181222f6922..b6d279ae4726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest { ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9bca4e7e660d..952637c5f9cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -import 
org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. - TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + setupTestData() test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2c0879927a12..ab2644eb4581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) @@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - } - - override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - } - - before { ctx.cacheTable("pruningData") } - after { - ctx.uncacheTable("pruningData") + override protected def afterAll(): Unit = { + try { + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons @@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 79e903c2bbd4..8998f5111124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext -class ExchangeSuite extends SparkPlanTest { +class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5582caa0d366..fad93b014c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ @@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = TestSQLContext + setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") + val rowRDD = ctx.sparkContext.parallelize(row :: Nil) + ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { 
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - dropTempTable("testLimit") + ctx.dropTempTable("testLimit") } - val origThreshold = conf.autoBroadcastJoinThreshold + val origThreshold = ctx.conf.autoBroadcastJoinThreshold val simpleTypes = NullType :: @@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + val origThreshold = ctx.conf.autoBroadcastJoinThreshold + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") val a = testData.as("a") - val b = table("tiny").as("b") + val b = ctx.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } @@ -157,13 +158,27 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("efficient limit -> project -> sort") { - val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } + + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. 
+ val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } test("PartitioningCollection") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index dd08e9025a92..ef6ad59b71fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String -class RowFormatConvertersSuite extends SparkPlanTest { +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c @@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val 
input = Seq(("hello", 1), ("world", 2)) checkAnswer( - TestSQLContext.createDataFrame(input), + ctx.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(TestSQLContext) + SparkPlan.currentContext.set(ctx) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a2c10fdaf6cd..8fa77b0fcb7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext -class SortSuite extends SparkPlanTest { +class SortSuite extends SparkPlanTest with SharedSQLContext { // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f46855edfe0d..3a87f374d94b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,29 +17,27 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.util._ + /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - protected def sqlContext: SQLContext = TestSQLContext +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def _sqlContext: SQLContext /** * Creates a DataFrame from a local Seq of Product. 
*/ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - sqlContext.implicits.localSeqToDataFrameHolder(data) + _sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -151,13 +149,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -172,7 +170,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -212,12 +210,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -280,10 +278,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = _sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 88bce0e319f9..3158458edb83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -19,25 +19,28 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * A test suite that generates randomized data to test the [[TungstenSort]] operator. */ -class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + super.beforeAll() + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + try { + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } finally { + super.afterAll() + } } test("sort followed by limit") { @@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { } test("sorting updates peak execution memory") { - val sc = TestSQLContext.sparkContext + val sc = ctx.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), @@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = ctx.createDataFrame( + ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index e03473041c3e..d1f0b2b1fc52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String * * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. */ -class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ @@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() @@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with an empty map") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, @@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with empty records") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index a9515a03acf2..d3be568a8758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. */ -class UnsafeKVExternalSorterSuite extends SparkFunSuite { - +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) @@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { inputData: Seq[(InternalRow, InternalRow)], pageSize: Long, spill: Boolean): Unit = { - // Calling this make sure we have block manager and everything else setup. 
- TestSQLContext val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) val shuffleMemMgr = new TestShuffleMemoryManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 40b47ae18d64..bd02c73a26ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -25,6 +25,18 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ + +/** + * used to test close InputStream in UnsafeRowSerializer + */ +class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStream(buf) { + var closed: Boolean = false + override def close(): Unit = { + closed = true + super.close() + } +} + class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { @@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 000000000000..5fdb82b06772 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { + + test("memory acquired on construction") { + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 92022ff23d2c..1174b27732f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} @@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} -import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { - - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - override def sqlContext: SQLContext = ctx // used by SQLTestUtils - - import ctx.sql - import ctx.implicits._ +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) @@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext - val relation0 = new JSONRelation( Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with 
TestJsonData { "abd") ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) - checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 88% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 369df5653060..2864181cf91d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,17 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -trait TestJsonData { - - protected def ctx: SQLContext +private[json] trait TestJsonData { + protected def _sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -36,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -47,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -65,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" 
:: """{"c":[33, 44]}""" :: @@ -80,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -96,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -150,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -158,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -167,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -190,7 +189,7 @@ trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,9 +197,8 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - lazy val singleRow: RDD[String] = - ctx.sparkContext.parallelize( - """{"a":123}""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala new file mode 100644 index 000000000000..45db619567a2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.ByteBuffer +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConverters.seqAsJavaListConverter +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord +import org.apache.hadoop.fs.Path +import org.apache.parquet.avro.AvroParquetWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit): Unit = { + logInfo( + s"""Writing Avro records with the following Avro schema into Parquet file: + | + |${schema.toString(true)} + """.stripMargin) + + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() + } + + test("required primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroPrimitives](path, AvroPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write( + AvroPrimitives.newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setStringColumn(s"val_$i") + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + }) + } + } + + test("optional primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroOptionalPrimitives](path, AvroOptionalPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = if (i % 3 == 0) { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(null) + .setMaybeIntColumn(null) + .setMaybeLongColumn(null) + .setMaybeFloatColumn(null) + .setMaybeDoubleColumn(null) + .setMaybeBinaryColumn(null) + .setMaybeStringColumn(null) + .build() + } else { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeIntColumn(i) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeFloatColumn(i.toFloat + 0.1f) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeStringColumn(s"val_$i") + .build() + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + if (i % 3 == 0) { + Row.apply(Seq.fill(7)(null): _*) + } else { + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + } + }) + } + } + + test("non-nullable arrays") { + withTempPath { dir => + val path = dir.getCanonicalPath + + 
withWriter[AvroNonNullableArrays](path, AvroNonNullableArrays.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = { + val builder = + AvroNonNullableArrays.newBuilder() + .setStringsColumn(Seq.tabulate(3)(i => s"val_$i").asJava) + + if (i % 3 == 0) { + builder.setMaybeIntsColumn(null).build() + } else { + builder.setMaybeIntsColumn(Seq.tabulate(3)(Int.box).asJava).build() + } + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(i => s"val_$i"), + if (i % 3 == 0) null else Seq.tabulate(3)(identity)) + }) + } + } + + ignore("nullable arrays (parquet-avro 1.7.0 does not properly support this)") { + // TODO Complete this test case after upgrading to parquet-mr 1.8+ + } + + test("SPARK-10136 array of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroArrayOfArray](path, AvroArrayOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroArrayOfArray.newBuilder() + .setIntArraysColumn( + Seq.tabulate(3, 3)((i, j) => i * 3 + j: Integer).map(_.asJava).asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) + }) + } + } + + test("map of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroMapOfArray](path, AvroMapOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroMapOfArray.newBuilder() + .setStringToIntsColumn( + Seq.tabulate(3) { i => + i.toString -> Seq.tabulate(3)(j => i + j: Integer).asJava + }.toMap.asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) + }) + } + } + + test("various complex types") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } + } + + def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { + def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Nested + .newBuilder() + .setNestedIntsColumn(Seq.tabulate(3)(j => i + j + m: Integer).asJava) + .setNestedStringColumn(s"val_${i + m}") + .build() + }.asJava + }.toMap.asJava + } + + ParquetAvroCompat + .newBuilder() + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}").asJava) + .setStringToIntColumn(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap.asJava) + .setComplexColumn(makeComplexColumn(i)) + .build() + } + + test("SPARK-9407 Push down predicates involving Parquet ENUM columns") { + import testImplicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + 
checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala similarity index 54% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 57478931cd50..d85c564e3e8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -15,49 +15,41 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet -import java.io.File +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -import org.apache.spark.util.Utils -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - protected var parquetStore: File = _ - - /** - * Optional path to a staging subdirectory which may be created during query processing - * (Hive does this). - * Parquet files under this directory will be ignored in [[readParquetSchema()]] - * @return an optional staging directory to ignore when scanning for parquet files. - */ - protected def stagingDir: Option[String] = None - - override protected def beforeAll(): Unit = { - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - Utils.deleteRecursively(parquetStore) +/** + * Helper class for testing Parquet compatibility. 
+ */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String): MessageType = { + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => - status.getPath.getName.startsWith("_") || - stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) - } + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } + + protected def logParquetSchema(path: String): Unit = { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} + """.stripMargin) + } } object ParquetCompatibilityTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b6a7c4fbddbd..f067112cfca9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,17 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.test.SharedSQLContext /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -39,8 +39,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. 
*/ -class ParquetFilterSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, @@ -55,20 +54,22 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + val analyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters - }.flatten.reduceOption(_ && _) + }.flatten + assert(analyzedPredicate.nonEmpty) - assert(maybeAnalyzedPredicate.isDefined) - maybeAnalyzedPredicate.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(pred) + val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate) + assert(selectedFilters.nonEmpty) + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") maybeFilter.foreach { f => // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) assert(f.getClass === filterClass) } } - checker(query, expected) } } @@ -109,43 +110,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } - test("filter pushdown - short") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } - } - test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -154,13 +130,13 @@ class ParquetFilterSuite extends QueryTest with 
ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -171,6 +147,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -179,13 +156,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -196,6 +173,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -204,13 +182,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -221,6 +199,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -229,13 +208,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, 
classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -247,6 +226,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") checkFilterPredicate( '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) @@ -256,13 +236,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], "1") checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } @@ -274,6 +254,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( @@ -288,20 +269,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) checkBinaryFilterPredicate( '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b415da5b8c13..e6b0a2ea95e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -37,6 +37,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. @@ -369,6 +369,30 @@ class ParquetIOSuite extends QueryTest with ParquetTest { test("SPARK-6352 DirectParquetOutputCommitter") { val clonedConf = new Configuration(configuration) + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + classOf[DirectParquetOutputCommitter].getCanonicalName) + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { @@ -390,7 +414,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(configuration) @@ -399,7 +424,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { configuration.set( "spark.sql.parquet.output.committer.class", - classOf[BogusParquetOutputCommitter].getCanonicalName) + classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) try { val message = intercept[SparkException] { @@ -425,12 +450,54 @@ class ParquetIOSuite extends QueryTest with ParquetTest { }.toString assert(errorMessage.contains("UnknownHostException")) } + + test("SPARK-7837 Do not close output writer twice when commitTask() fails") { + val clonedConf = new Configuration(configuration) + + // Using an output committer that always fails when committing a task, so that both + // `commitTask()` and `abortTask()` are invoked.
+ configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) + + try { + // Before fixing SPARK-7837, the following code results in an NPE because both + // `commitTask()` and `abortTask()` try to close output writers. + + withTempPath { dir => + val m1 = intercept[SparkException] { + sqlContext.range(1).coalesce(1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m1.contains("Intentional exception for testing purposes")) + } + + withTempPath { dir => + val m2 = intercept[SparkException] { + val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + df.write.partitionBy("a").parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(m2.contains("Intentional exception for testing purposes")) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } } -class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { override def commitJob(jobContext: JobContext): Unit = { sys.error("Intentional exception for testing purposes") } } + +class TaskCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitTask(context: TaskAttemptContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11..ed8bafb10c60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String -import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. 
case class ParquetData(intField: Int, stringField: String) @@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala new file mode 100644 index 000000000000..b290429c2a02 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + + private def readParquetProtobufFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } + + test("unannotated array of primitive type") { + checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + } + + test("unannotated array of struct") { + checkAnswer( + readParquetProtobufFile("old-repeated-message.parquet"), + Row( + Seq( + Row("First inner", null, null), + Row(null, "Second inner", null), + Row(null, null, "Third inner")))) + + checkAnswer( + readParquetProtobufFile("proto-repeated-struct.parquet"), + Row( + Seq( + Row("0 - 1", "0 - 2", "0 - 3"), + Row("1 - 1", "1 - 2", "1 - 3")))) + + checkAnswer( + readParquetProtobufFile("proto-struct-with-array-many.parquet"), + Seq( + Row( + Seq( + Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"), + Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))), + Row( + Seq( + Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"), + Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))), + Row( + Seq( + Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"), + Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3"))))) + } + + test("struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("proto-struct-with-array.parquet"), + Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) + } + + test("unannotated array of struct with unannotated array") { + checkAnswer( + readParquetProtobufFile("nested-array-struct.parquet"), + Seq( + Row(2, Seq(Row(1, Seq(Row(3))))), + Row(5, Seq(Row(4, Seq(Row(6))))), + Row(8, Seq(Row(7, Seq(Row(9))))))) + } + + test("unannotated array of string") { + checkAnswer( + readParquetProtobufFile("proto-repeated-string.parquet"), + Seq( + Row(Seq("hello", "world")), + Row(Seq("good", "bye")), + Row(Seq("one", "two", "three")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala similarity index 72% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5c65a8ec57f0..b7b70c2bbbd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -15,22 +15,21 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -class ParquetQuerySuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = ctx.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -202,4 +201,32 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { assert(Decimal("67123.45") === Decimal(decimal)) } } + + test("SPARK-10005 Schema merging for nested struct") { + val sqlContext = _sqlContext + import sqlContext.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + def append(df: DataFrame): Unit = { + df.write.mode(SaveMode.Append).parquet(path) + } + + // Note that both the following two DataFrames contain a single struct column with multiple + // nested fields. 
+ append((1 to 2).map(i => Tuple1((i, i))).toDF()) + append((1 to 2).map(i => Tuple1((i, i, i))).toDF()) + + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer( + sqlContext.read.option("mergeSchema", "true").parquet(path), + Seq( + Row(Row(1, 1, null)), + Row(Row(2, 2, null)), + Row(Row(1, 1, 1)), + Row(Row(2, 2, 2)))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419..9dcbc1a047be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -15,20 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. @@ -585,6 +583,36 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin) + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 7 - " + + "parquet-protobuf primitive lists", + new StructType() + .add("f1", ArrayType(IntegerType, containsNull = false), nullable = false), + """message root { + | repeated int32 f1; + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type 8 - " + + "parquet-protobuf non-primitive lists", + { + val elementType = + new StructType() + .add("c1", StringType, nullable = true) + .add("c2", IntegerType, nullable = false) + + new StructType() + .add("f1", ArrayType(elementType, containsNull = false), nullable = false) + }, + """message root { + | repeated group f1 { + | optional binary c1 (UTF8); + | required int32 c2; + | } + |} + """.stripMargin) + // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST // ======================================================= diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 86% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 64e94056f209..5dbc7d1630f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,16 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => +private[sql] trait ParquetTest extends SQLTestUtils { + protected def _sqlContext: SQLContext + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) } /** @@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + _sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 1c532d78790d..b789c5a106e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -15,16 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext -class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 239deb797384..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0554e11d252b..53a0e53fd771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -15,80 +15,73 @@ * limitations under the License. */ -// TODO: uncomment the test here! It is currently failing due to -// bad interaction with org.apache.spark.sql.test.TestSQLContext. +package org.apache.spark.sql.execution.joins -// scalastyle:off -//package org.apache.spark.sql.execution.joins -// -//import scala.reflect.ClassTag -// -//import org.scalatest.BeforeAndAfterAll -// -//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -//import org.apache.spark.sql.functions._ -//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} -// -///** -// * Test various broadcast join operators with unsafe enabled. -// * -// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs -// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. -// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in -// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without -// * serializing the hashed relation, which does not happen in local mode. -// */ -//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { -// private var sc: SparkContext = null -// private var sqlContext: SQLContext = null -// -// /** -// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. 
-// */ -// override def beforeAll(): Unit = { -// super.beforeAll() -// val conf = new SparkConf() -// .setMaster("local-cluster[2,1,1024]") -// .setAppName("testing") -// sc = new SparkContext(conf) -// sqlContext = new SQLContext(sc) -// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) -// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) -// } -// -// override def afterAll(): Unit = { -// sc.stop() -// sc = null -// sqlContext = null -// } -// -// /** -// * Test whether the specified broadcast join updates the peak execution memory accumulator. -// */ -// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { -// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { -// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") -// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") -// // Comparison at the end is for broadcast left semi join -// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") -// val df3 = df1.join(broadcast(df2), joinExpression, joinType) -// val plan = df3.queryExecution.executedPlan -// assert(plan.collect { case p: T => p }.size === 1) -// plan.executeCollect() -// } -// } -// -// test("unsafe broadcast hash join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") -// } -// -// test("unsafe broadcast hash outer join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") -// } -// -// test("unsafe broadcast left semi join updates peak execution memory") { -// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") -// } -// -//} -// scalastyle:on +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + private var sc: SparkContext = null + private var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator.
+ */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b..4c9187a9a710 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,11 +22,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { @@ -35,7 +37,8 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -45,11 +48,13 @@ class HashedRelationSuite extends SparkFunSuite { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -62,17 +67,19 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(data(1)) === data(1)) 
assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -94,5 +101,37 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) + } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 000000000000..cc649b9bd4c4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val myUpperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + private lazy val myLowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + private lazy val myTestData = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + 
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + () => (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4..a1a617d7b739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,89 +1,214 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} + +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testOuterJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + joinType: JoinType, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + } + } + + // --- Basic outer joins 
------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3d..baa86e320d98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,100 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - val right = 
Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testLeftSemiJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala new file mode 100644 index 000000000000..80006bf077fe --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -0,0 +1,577 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor 
license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") + val f = () => { + l += 1L + l.add(1L) + } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = ctx.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Should have found boxing in this test") + } + } + + /** + * Call `df.collect()` and verify that the collected metrics are the same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = ctx.listener.executionIdToData.keySet + df.collect() + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = ctx.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition where we may miss some jobs + // TODO Change it to "=" once we fix the race condition of missing the JobStarted event.
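(Editor's note, not part of the patch: the `expectedMetrics` argument documented above maps a plan node id to the operator name and its expected metric values. A minimal sketch of that shape, borrowing the values from the Filter metrics test later in this suite; the `expected` name and the commented call are illustrative only.)

```scala
// Shape: nodeId -> (operatorName, metricName -> metricValue)
val expected: Map[Long, (String, Map[String, Any])] = Map(
  0L -> ("Filter", Map(
    "number of input rows" -> 2L,
    "number of output rows" -> 1L)))
// e.g. testSparkPlanMetrics(person.filter('age < 25), expectedNumOfJobs = 1, expected)
```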
+ assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = ctx.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. + logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... 
-> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = ctx.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = ctx.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = ctx.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
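(Editor's note, not part of the patch: the visitor constructed below reports wrapper construction and `valueOf` calls on the classes listed in `primitiveBoxingClassName`. A rough sketch of Scala expressions that would or would not be flagged; the value names are made up for the illustration.)

```scala
// INVOKESPECIAL java/lang/Long.<init>: explicit wrapper construction is flagged
val viaConstructor: java.lang.Long = new java.lang.Long(1L)
// INVOKESTATIC java/lang/Long.valueOf: boxing through the factory method is flagged
val viaValueOf: java.lang.Long = java.lang.Long.valueOf(1L)
// A primitive addition stays primitive (no wrapper call), so it is not reported
val primitiveSum: Long = 1L + 1L
```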
+ return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g., new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which require ASM5. + // So if the class is ASM5 (e.g., java.lang.Long when using a JDK8 runtime to run this code), + // then ClassReader will throw IllegalArgumentException. + // However, since this is only for testing, it's safe to skip these classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 69a561e16aa1..80d1e8895694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -15,22 +15,22 @@ * limitations under the License.
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SQLListenerSuite extends SparkFunSuite { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private def createTestDataFrame: DataFrame = { - import TestSQLContext.implicits._ Seq( (1, 1), (2, 2) @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("basic") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f..0edac0848c3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. 
We need these to test @@ -444,5 +443,4 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 84b52ca2c733..5dc3a2c07b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -23,11 +23,13 @@ import java.util.Properties import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala deleted file mode 100644 index d22160f5384f..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.metric - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import scala.collection.mutable - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.util.Utils - - -class SQLMetricsSuite extends SparkFunSuite { - - test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") - val f = () => { l += 1L } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - } - - test("IntSQLMetric should not box Int") { - val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") - val f = () => { l += 1 } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - } - - test("Normal accumulator should do boxing") { - // We need this test to make sure BoxingFinder works. - val l = TestSQLContext.sparkContext.accumulator(0L) - val f = () => { l += 1L } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") - } - } -} - -private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) - -/** - * If `method` is null, search all methods of this class recursively to find if they do some boxing. - * If `method` is specified, only search this method of the class to speed up the searching. - * - * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. - */ -private class BoxingFinder( - method: MethodIdentifier[_] = null, - val boxingInvokes: mutable.Set[String] = mutable.Set.empty, - visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) - extends ClassVisitor(ASM4) { - - private val primitiveBoxingClassName = - Set("java/lang/Long", - "java/lang/Double", - "java/lang/Integer", - "java/lang/Float", - "java/lang/Short", - "java/lang/Character", - "java/lang/Byte", - "java/lang/Boolean") - - override def visitMethod( - access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): - MethodVisitor = { - if (method != null && (method.name != name || method.desc != desc)) { - // If method is specified, skip other methods. 
- return new MethodVisitor(ASM4) {} - } - - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { - if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { - if (primitiveBoxingClassName.contains(owner)) { - // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) - boxingInvokes.add(s"$owner.$name") - } - } else { - // scalastyle:off classforname - val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, - Thread.currentThread.getContextClassLoader) - // scalastyle:on classforname - val m = MethodIdentifier(classOfMethodOwner, name, desc) - if (!visitedMethods.contains(m)) { - // Keep track of visited methods to avoid potential infinite cycles - visitedMethods += m - BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => - visitedMethods += m - cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) - } - } - } - } - } - } -} - -private object BoxingFinder { - - def getClassReader(cls: Class[_]): Option[ClassReader] = { - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - val baos = new ByteArrayOutputStream(128) - // Copy data over, before delegating to ClassReader - - // else we can run out of open file handles. - Utils.copyStream(resourceStream, baos, true) - // ASM4 doesn't support Java 8 classes, which requires ASM5. - // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), - // then ClassReader will throw IllegalArgumentException, - // However, since this is only for testing, it's safe to skip these classes. - try { - Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) - } catch { - case _: IllegalArgumentException => None - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala deleted file mode 100644 index bfa427349ff6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.nio.ByteBuffer -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.fs.Path -import org.apache.parquet.avro.AvroParquetWriter - -import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} - -class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { - import ParquetCompatibilityTest._ - - override val sqlContext: SQLContext = TestSQLContext - - override protected def beforeAll(): Unit = { - super.beforeAll() - - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() - } - - test("Read Parquet file generated by parquet-avro") { - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) - - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - Row( - i % 2 == 0, - i, - i.toLong * 10, - i.toFloat + 0.1f, - i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, - Seq.tabulate(3) { n => - (i + n).toString -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) - } - - def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { - def nullable[T <: AnyRef] = makeNullable[T](i) _ - - def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { - mapAsJavaMap(Seq.tabulate(3) { n => - (i + n).toString -> seqAsJavaList(Seq.tabulate(3) { m => - Nested - .newBuilder() - .setNestedIntsColumn(seqAsJavaList(Seq.tabulate(3)(j => i + j + m))) - .setNestedStringColumn(s"val_${i + m}") - .build() - }) - }.toMap) - } - - ParquetAvroCompat - .newBuilder() - .setBoolColumn(i % 2 == 0) - .setIntColumn(i) - .setLongColumn(i.toLong * 10) - .setFloatColumn(i.toFloat + 0.1f) - .setDoubleColumn(i.toDouble + 0.2d) - .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) - .setStringColumn(s"val_$i") - - .setMaybeBoolColumn(nullable(i % 2 == 0: java.lang.Boolean)) - .setMaybeIntColumn(nullable(i: Integer)) - .setMaybeLongColumn(nullable(i.toLong: java.lang.Long)) - .setMaybeFloatColumn(nullable(i.toFloat + 0.1f: java.lang.Float)) - .setMaybeDoubleColumn(nullable(i.toDouble + 0.2d: java.lang.Double)) - .setMaybeBinaryColumn(nullable(ByteBuffer.wrap(s"val_$i".getBytes))) - .setMaybeStringColumn(nullable(s"val_$i")) - - .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}")) - .setStringToIntColumn( - mapAsJavaMap(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap)) - .setComplexColumn(makeComplexColumn(i)) - - .build() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85d..9bc3f6bcf6fc 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,28 +19,32 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { @@ -51,7 +55,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -75,7 +79,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -92,7 +96,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -107,7 +111,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -122,7 +126,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -139,7 +143,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -158,7 +162,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -175,7 +179,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -188,7 +192,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING 
org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -199,7 +203,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 1a4d41b02ca6..853707c036c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -18,11 +18,40 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + class FakeSourceOne extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -35,7 +64,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { class FakeSourceTwo extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -48,7 +77,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { class FakeSourceThree extends RelationProvider with DataSourceRegister { - def format(): String = "gathering quorum" + def shortName(): String = "gathering quorum" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -58,28 +87,3 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } -// please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { - - test("data sources with the same name") { - intercept[RuntimeException] { - caseInsensitiveContext.read.format("Fluet da Bomb").load() - } - } - - test("load data source from format alias") { - caseInsensitiveContext.read.format("gathering quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("specify full 
classname with duplicate formats") { - caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("Loading Orc") { - intercept[ClassNotFoundException] { - caseInsensitiveContext.read.format("orc").load() - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 84855ce45e91..5f8514e1a241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580..d74d29fb0beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,18 +17,23 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { +private[sql] abstract class DataSourceTest extends QueryTest { + protected def _sqlContext: SQLContext + // We want to test some edge cases. 
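To illustrate the `shortName()` registration used by the fake sources above: a sketch, not part of this patch, of how a provider could expose an alias through DataSourceRegister so that tests and DDL can refer to it by that alias rather than by its fully qualified class name. `ExampleSource` and the "example" alias are hypothetical, and the class would also need an entry under META-INF/services/org.apache.spark.sql.sources.DataSourceRegister to be discovered.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    class ExampleSource extends RelationProvider with DataSourceRegister {
      // The alias users would write in .format("example") or USING example.
      def shortName(): String = "example"

      override def createRelation(ctx: SQLContext, parameters: Map[String, String]): BaseRelation =
        new BaseRelation {
          override def sqlContext: SQLContext = ctx
          override def schema: StructType =
            StructType(Seq(StructField("stringType", StringType, nullable = false)))
        }
    }

With such a registration in place, the `USING json` simplification in the suites above is the same mechanism at work: the alias resolves to the registered provider instead of requiring the full DefaultSource class name.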
- protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(_sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3..c81c3d398280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -56,6 +57,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v @@ -95,11 +97,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - import caseInsensitiveContext.sql - - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index cdbfaf6455fe..78bd3e558296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class InsertSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") @@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple 
INSERT OVERWRITE a JSONRelation") { @@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c86ddd7c83e5..79b6e9b45c00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -19,21 +19,21 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class PartitionedWriteSuite extends QueryTest { - import TestSQLContext.implicits._ +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("write many partitions") { val path = Utils.createTempDir() path.delete() - val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + val df = ctx.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest { val path = Utils.createTempDir() path.delete() - val base = TestSQLContext.range(100) + val base = ctx.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 0d5183444af7..a89c5f8007e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -114,7 
+117,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 3cbf5467b253..27d1cd92fca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,14 +22,39 @@ import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 31730a3d3f8d..f18546b4c2d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,22 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends 
DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index e34e0956d1fd..12af8068c398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource @@ -95,8 +96,8 @@ case class AllDataTypesScan( } } -class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( @@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest { Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen @@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala new file mode 100644 index 000000000000..152c9c8459de --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.{IOException, InputStream} + +import scala.sys.process.BasicIO + +object ProcessTestUtils { + class ProcessOutputCapturer(stream: InputStream, capture: String => Unit) extends Thread { + this.setDaemon(true) + + override def run(): Unit = { + try { + BasicIO.processFully(capture)(stream) + } catch { case _: IOException => + // Ignores the IOException thrown when the process termination, which closes the input + // stream abruptly. + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 000000000000..1374a97476ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def _sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. 
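For orientation, a minimal sketch, not part of this patch, of how a suite consumes this lazy test data, assuming it mixes in the SharedSQLContext trait introduced later in this change; `ExampleQuerySuite` is a hypothetical name. Touching `testData` inside a test body both materializes the DataFrame and registers the "testData" temp table, which is why every val here stays lazy until the SQLContext exists.

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.test.SharedSQLContext

    class ExampleQuerySuite extends QueryTest with SharedSQLContext {

      test("lazy test data registers its temp table on first use") {
        // First access materializes the DataFrame and registers "testData".
        assert(testData.count() === 100)
        // The temp table is now visible to plain SQL through the shared context.
        checkAnswer(sql("SELECT value FROM testData WHERE key = 1"), Row("1"))
      }
    }

Suites that only reference the tables by name in SQL strings can instead call `setupTestData()` in their constructor so everything is registered up front.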
+ + protected lazy val testData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = _sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = _sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + 
rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + _sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec..cdd691e03589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -21,15 +21,71 @@ import java.io.File import java.util.UUID import scala.util.Try +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils -trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def _sqlContext: SQLContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = _sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. 
+ */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def configuration: Configuration = { + _sqlContext.sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite => */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(_sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) + case (key, None) => _sqlContext.conf.unsetConf(key) } } } @@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite => * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(_sqlContext.dropTempTable) } /** @@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite => protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + _sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite => val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + _sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite => * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + _sqlContext.sql(s"USE $db") + try f finally _sqlContext.sql(s"USE default") + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(_sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..8a061b6bc690 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.{ColumnName, SQLContext} + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +private[sql] trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx + protected override def _sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + // This must be duplicated here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala similarity index 54% rename from sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b3a4231da91c..92ef2f7d74ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -17,40 +17,36 @@ package org.apache.spark.sql.test -import scala.language.implicitConversions - import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. 
+ */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) } + // Use fewer partitions to speed up testing + protected[sql] override def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() } + private object testData extends SQLTestData { + protected override def _sqlContext: SQLContext = self + } } - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-avro.sh similarity index 76% rename from sql/core/src/test/scripts/gen-code.sh rename to sql/core/src/test/scripts/gen-avro.sh index 5d8d8ad08555..48174b287fd7 100755 --- a/sql/core/src/test/scripts/gen-code.sh +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -22,10 +22,9 @@ cd - rm -rf $BASEDIR/gen-java mkdir -p $BASEDIR/gen-java -thrift\ - --gen java\ - -out $BASEDIR/gen-java\ - $BASEDIR/thrift/parquet-compat.thrift - -avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr -avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 000000000000..ada432c68ab9 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. 
+BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift index fa5ed8c62306..98bf778aec5d 100644 --- a/sql/core/src/test/thrift/parquet-compat.thrift +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -15,7 +15,7 @@ * limitations under the License. */ -namespace java org.apache.spark.sql.parquet.test.thrift +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift enum Suit { SPADES, diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 2dfbcb2425a3..3566c87dd248 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -86,6 +86,13 @@ selenium-java test + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + ${project.version} + test + target/scala-${scala.binary.version}/classes diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 9c047347cb58..dd9fef9206d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -76,7 +76,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } @@ -152,16 +152,26 @@ object HiveThriftServer2 extends Logging { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - var onlineSessionNum: Int = 0 - val sessionList = new mutable.LinkedHashMap[String, SessionInfo] - val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] - val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) - val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - var totalRunning = 0 - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: 
SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -173,13 +183,15 @@ object HiveThriftServer2 extends Logging { } def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { - val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) - sessionList.put(sessionId, info) - onlineSessionNum += 1 - trimSessionIfNecessary() + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } } - def onSessionClosed(sessionId: String): Unit = { + def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 trimSessionIfNecessary() @@ -190,7 +202,7 @@ object HiveThriftServer2 extends Logging { sessionId: String, statement: String, groupId: String, - userName: String = "UNKNOWN"): Unit = { + userName: String = "UNKNOWN"): Unit = synchronized { val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) @@ -200,27 +212,29 @@ object HiveThriftServer2 extends Logging { totalRunning += 1 } - def onStatementParsed(id: String, executionPlan: String): Unit = { + def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED } def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { - executionList(id).finishTimestamp = System.currentTimeMillis - executionList(id).detail = errorMessage - executionList(id).state = ExecutionState.FAILED - totalRunning -= 1 - trimExecutionIfNecessary() + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } } - def onStatementFinish(id: String): Unit = { + def onStatementFinish(id: String): Unit = synchronized { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 trimExecutionIfNecessary() } - private def trimExecutionIfNecessary() = synchronized { + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => @@ -229,7 +243,7 @@ object HiveThriftServer2 extends Logging { } } - private def trimSessionIfNecessary() = synchronized { + private def trimSessionIfNecessary() = { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d3886142b388..7799704c819d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -39,7 +39,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging 
import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver @@ -114,7 +114,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 10c83d8b27a2..e990bd06011f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -39,14 +39,16 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - generateBasicStats() ++ -
      <br/> ++
-      <h4>
-      {listener.onlineSessionNum} session(s) are online,
-      running {listener.totalRunning} SQL statement(s)
-      </h4> ++
-      generateSessionStatsTable() ++
-      generateSQLStatsTable()
+    listener.synchronized { // make sure all parts in this page are consistent
+      generateBasicStats() ++
+      <br/> ++
+      <h4>
+      {listener.getOnlineSessionNum} session(s) are online,
+      running {listener.getTotalRunning} SQL statement(s)
+      </h4>
    ++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } @@ -65,11 +67,11 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(): Seq[Node] = { - val numStatement = listener.executionList.size + val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.executionList.values + val dataRows = listener.getExecutionList def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -136,15 +138,15 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = - listener.sessionList.values + val dataRows = sessionList val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/sql/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 3b01afa603ce..af16cb31df18 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -40,21 +40,22 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val sessionStat = listener.sessionList.find(stat => { - stat._1 == parameterId - }).getOrElse(null) - require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") val content = - generateBasicStats() ++ -
      <br/> ++
-      <h4>
-      User {sessionStat._2.userName},
-      IP {sessionStat._2.ip},
-      Session created at {formatDate(sessionStat._2.startTimestamp)},
-      Total run {sessionStat._2.totalExecution} SQL
-      </h4> ++
-      generateSQLStatsTable(sessionStat._2.sessionId)
+    listener.synchronized { // make sure all parts in this page are consistent
+      val sessionStat = listener.getSession(parameterId).getOrElse(null)
+      require(sessionStat != null, "Invalid sessionID[" + parameterId + "]")
+
+      generateBasicStats() ++
+      <br/> ++
+      <h4>
+      User {sessionStat.userName},
+      IP {sessionStat.ip},
+      Session created at {formatDate(sessionStat.startTimestamp)},
+      Total run {sessionStat.totalExecution} SQL
+      </h4>
    ++ + generateSQLStatsTable(sessionStat.sessionId) + } UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } @@ -73,13 +74,13 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(sessionID: String): Seq[Node] = { - val executionList = listener.executionList - .filter(_._2.sessionId == sessionID) + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -146,10 +147,11 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { val dataRows = - listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + sessionList.sortBy(_.startTimestamp).reverse.map ( session => Seq( session.userName, session.ip, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 94fd8a6bb60b..4eabeaa6735e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,9 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. 
*/ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { - override val name = "SQL" + override val name = "JDBC/ODBC Server" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 121b3e077f71..e59a14ec00d5 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,18 +18,19 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.sql.Timestamp +import java.util.Date import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} -import scala.util.Failure +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkFunSuite} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary @@ -70,6 +71,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. + val queriesString = queries.map(_ + "\n").mkString + val command = { val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" @@ -83,13 +87,14 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. - val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object def captureOutput(source: String)(line: String): Unit = lock.synchronized { - buffer += s"$source> $line" + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. + buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" + // If we haven't found all expected answers and another expected answer comes up... 
if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { next += 1 @@ -98,48 +103,27 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { foundAllExpectedAnswers.trySuccess(()) } } else { - errorResponses.foreach( r => { + errorResponses.foreach { r => if (line.startsWith(r)) { foundAllExpectedAnswers.tryFailure( new RuntimeException(s"Failed with error line '$line'")) - }}) - } - } - - // Searching expected output line from both stdout and stderr of the CLI process - val process = (Process(command, None) #< queryStream).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - - // catch the output value - class exitCodeCatcher extends Runnable { - var exitValue = 0 - - override def run(): Unit = { - try { - exitValue = process.exitValue() - } catch { - case rte: RuntimeException => - // ignored as it will get triggered when the process gets destroyed - logDebug("Ignoring exception while waiting for exit code", rte) - } - if (exitValue != 0) { - // process exited: fail fast - foundAllExpectedAnswers.tryFailure( - new RuntimeException(s"Failed with exit code $exitValue")) + } } } } - // spin off the code catche thread. No attempt is made to kill this - // as it will exit once the launched process terminates. - val codeCatcherThread = new Thread(new exitCodeCatcher()) - codeCatcherThread.start() + + val process = new ProcessBuilder(command: _*).start() + + val stdinWriter = new OutputStreamWriter(process.getOutputStream) + stdinWriter.write(queriesString) + stdinWriter.flush() + stdinWriter.close() + + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - Await.ready(foundAllExpectedAnswers.future, timeout) - foundAllExpectedAnswers.future.value match { - case Some(Failure(t)) => throw t - case _ => - } + Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => val message = s""" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 17e7044c46ec..ded42bca9971 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -22,10 +22,9 @@ import java.net.URL import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} -import scala.concurrent.ExecutionContext.Implicits.global -import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -38,11 +37,12 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{Ignore, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils +import 
org.apache.spark.{Logging, SparkFunSuite} object TestData { def getTestDataFilePath(name: String): URL = { @@ -53,7 +53,6 @@ object TestData { val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") } -@Ignore // SPARK-9606 class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary @@ -380,7 +379,6 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } -@Ignore // SPARK-9606 class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.http @@ -484,7 +482,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl val tempLog4jConf = Utils.createTempDir().getCanonicalPath Files.write( - """log4j.rootCategory=INFO, console + """log4j.rootCategory=DEBUG, console |log4j.appender.console=org.apache.log4j.ConsoleAppender |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout @@ -493,7 +491,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl new File(s"$tempLog4jConf/log4j.properties"), UTF_8) - tempLog4jConf // + File.pathSeparator + sys.props("java.class.path") + tempLog4jConf } s"""$startScript @@ -521,7 +519,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl */ val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" - val SERVER_STARTUP_TIMEOUT = 1.minute + val SERVER_STARTUP_TIMEOUT = 3.minutes private def startThriftServer(port: Int, attempt: Int) = { warehousePath = Utils.createTempDir() @@ -543,17 +541,22 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") - val env = Seq( - // Disables SPARK_TESTING to exclude log4j.properties in test directories. - "SPARK_TESTING" -> "0", - // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be started - // at a time, which is not Jenkins friendly. - "SPARK_PID_DIR" -> pidDir.getCanonicalPath) - - logPath = Process(command, None, env: _*).lines.collectFirst { - case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) - }.getOrElse { - throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + logPath = { + val lines = Utils.executeAndGetOutput( + command = command, + extraEnvironment = Map( + // Disables SPARK_TESTING to exclude log4j.properties in test directories. + "SPARK_TESTING" -> "0", + // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be + // started at a time, which is not Jenkins friendly. + "SPARK_PID_DIR" -> pidDir.getCanonicalPath), + redirectStderr = true) + + lines.split("\n").collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } } val serverStarted = Promise[Unit]() @@ -561,30 +564,36 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl // Ensures that the following "tail" command won't fail. 
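// [Editor's sketch, not part of the patch] Both this suite and CliSuite signal
// readiness or failure through a Promise completed from the output-capturing
// callbacks and awaited with a timeout (SERVER_STARTUP_TIMEOUT is raised to
// 3 minutes in the surrounding hunks). The synchronization pattern in
// isolation, with hypothetical marker strings and canned log lines instead of
// a real tailed log file:
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object StartupSignalSketch {
  def main(args: Array[String]): Unit = {
    val serverStarted = Promise[Unit]()
    val successMarker = "Started ThriftBinaryCLIService"  // hypothetical marker
    val failureMarker = "Exception in thread"             // hypothetical marker

    // In the suite a callback like this is driven by the `tail -n +0 -f` process.
    def onLogLine(line: String): Unit = {
      if (line.contains(successMarker)) serverStarted.trySuccess(())
      if (line.contains(failureMarker)) {
        serverStarted.tryFailure(new RuntimeException(s"Failed with output '$line'"))
      }
    }

    Seq("starting...", "Started ThriftBinaryCLIService on port 10000").foreach(onLogLine)

    // Throws if the promise was failed or the timeout elapses.
    Await.result(serverStarted.future, 3.minutes)
  }
}
// [End of editor's sketch]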
logPath.createNewFile() val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) - val failureLines = Seq("HiveServer2 is stopped", "Exception in thread", "Error:") - logTailingProcess = + + logTailingProcess = { + val command = s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}".split(" ") // Using "-n +0" to make sure all lines in the log file are checked. - Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( - (line: String) => { - diagnosisBuffer += line - successLines.foreach(r => { - if (line.contains(r)) { - serverStarted.trySuccess(()) - } - }) - failureLines.foreach(r => { - if (line.contains(r)) { - serverStarted.tryFailure(new RuntimeException(s"Failed with output '$line'")) - } - }) - })) + val builder = new ProcessBuilder(command: _*) + val captureOutput = (line: String) => diagnosisBuffer.synchronized { + diagnosisBuffer += line + + successLines.foreach { r => + if (line.contains(r)) { + serverStarted.trySuccess(()) + } + } + } + + val process = builder.start() + + new ProcessOutputCapturer(process.getInputStream, captureOutput).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput).start() + process + } Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue() + Utils.executeAndGetOutput( + command = Seq(stopScript), + extraEnvironment = Map("SPARK_PID_DIR" -> pidDir.getCanonicalPath)) Thread.sleep(3.seconds.toMillis) warehousePath.delete() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de45..bf431cd6b026 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 567d7fa12ff1..17cc83087fb1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -189,6 +189,10 @@ class 
HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options // into the isolated client loader val metadataConf = new HiveConf() + + val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("defalt warehouse location is " + defaltWarehouseLocation) + // `configure` goes second to override other settings. val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure @@ -231,7 +235,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. val jars = @@ -284,12 +292,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(catalog.client.currentDatabase, tableName) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -303,7 +312,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { case relation: MetastoreRelation => @@ -531,7 +541,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7198a32df4a0..bbe8c1911bf8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, 
ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -86,9 +85,9 @@ private[hive] object HiveSerDe { serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) val key = source.toLowerCase match { - case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" - case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" - case _ => source.toLowerCase + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s } serdeMap.get(key) @@ -175,10 +174,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { + def invalidateTable(tableIdent: TableIdentifier): Unit = { + val databaseName = tableIdent.database.getOrElse(client.currentDatabase) + val tableName = tableIdent.table + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) } @@ -188,6 +190,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * Creates a data source table (a table created with USING clause) in Hive's metastore. * Returns true when the table has been created. Otherwise, false. */ + // TODO: Remove this in SPARK-10104. def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -204,7 +207,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive isExternal) } - private def createDataSourceTable( + def createDataSourceTable( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], @@ -309,11 +312,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val hiveTable = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - logInfo { - "Persisting data source relation with a single input path into Hive metastore in Hive " + - s"compatible format. Input path: ${relation.paths.head}" + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. + val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. 
Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) } - newHiveCompatibleMetastoreTable(relation, serde) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { @@ -352,10 +375,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { + hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + } + + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) + val database = tableIdent.database.getOrElse(client.currentDatabase) + new Path( - new Path(client.getDatabase(client.currentDatabase).location), - tableName.toLowerCase).toString + new Path(client.getDatabase(database).location), + tableIdent.table.toLowerCase).toString } def tableExists(tableIdentifier: Seq[String]): Boolean = { @@ -616,7 +645,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - desc.name, + TableIdentifier(desc.name), hive.conf.defaultDataSourceName, temporary = false, Array.empty[String], @@ -739,7 +768,7 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -888,6 +917,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. 
*/ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c3f29350101d..ad33dee555dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -729,6 +729,17 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } + case _ => throw new SemanticException( s"Unrecognized file format in STORED AS clause: ${child.getText}") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cd6cd322c94e..d38ad9127327 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,14 +83,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => - ExecutedCommand( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + val cmd = CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) + ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index a82e152dcda2..3811c152a7ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -88,6 +88,9 @@ private[hive] case class HiveTable( */ private[hive] trait ClientInterface { + /** Returns the Hive Version of this client. 
*/ + def version: HiveVersion + /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3d05b583cf9e..f49c97de8ff4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. */ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 6e826ce55220..8fc8935b1dc3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -429,7 +429,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -439,7 +439,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -461,6 +461,13 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, TimeUnit.MILLISECONDS).asInstanceOf[Long] } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + } private[client] class Shim_v1_0 extends Shim_v0_14 { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index a7d5a991948d..785603750841 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -42,11 +42,18 @@ private[hive] object IsolatedClientLoader { def forVersion( version: String, config: Map[String, String] = Map.empty, - ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) val files 
= resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion, ivyPath)) - new IsolatedClientLoader(hiveVersion(version), files, config) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 0503691a4424..b1b8439efa01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 05a78930afe3..d1699dd53681 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -120,9 +122,10 @@ case class AddFile(path: String) extends RunnableCommand { } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], @@ -130,9 +133,24 @@ case class CreateMetastoreDataSource( managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. 
Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -144,13 +162,13 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, Array.empty[String], provider, @@ -161,9 +179,10 @@ case class CreateMetastoreDataSource( } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -171,19 +190,34 @@ case class CreateMetastoreDataSourceAsSelect( query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(Seq(tableName))) { + if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -200,7 +234,7 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = @@ -249,7 +283,7 @@ case class CreateMetastoreDataSourceAsSelect( // the schema of df). 
It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, Some(resolved.relation.schema), partitionColumns, provider, @@ -258,7 +292,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. - hiveContext.refreshTable(tableName) + hiveContext.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 86142e5d66f3..b3d9f7f71a27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -107,6 +107,11 @@ private[orc] object OrcFilters extends Logging { .filter(isSearchableLiteral) .map(builder.equals(attribute, _)) + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + case LessThan(attribute, value) => Option(value) .filter(isSearchableLiteral) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0c344c63fde3..9f4f8b5789af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -49,9 +48,9 @@ import scala.collection.JavaConversions._ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "orc" + override def shortName(): String = "orc" - def createRelation( + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -144,7 +143,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 296cc5c5e0b0..4da86636ac10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -25,10 +25,8 @@ import scala.language.implicitConversions import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry -import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ @@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import 
org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ @@ -154,7 +152,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) @@ -276,10 +274,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), TestTable("episodes", s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", @@ -312,10 +307,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |STORED AS avro |TBLPROPERTIES ( | 'avro.schema.literal'='{ | "type": "record", diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 21b053f07a3b..a30dfa554eab 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -92,7 +92,7 @@ public void testUDAF() { DataFrame aggregatedDF = df.groupBy() .agg( - udaf.apply(true, col("value")), + udaf.distinct(col("value")), udaf.apply(col("value")), registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index a2247e3da155..2961b803f14a 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -65,7 +65,7 @@ public MyDoubleAvg() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index da29e24d267d..c71882a6e7be 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -60,7 +60,7 @@ public MyDoubleSum() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } diff --git a/sql/hive/src/test/resources/golden/Column pruning - 
non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 new file mode 100644 index 000000000000..9a276bc794c0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Column pruning - non-trivial top project with aliases - query test-0-515e406ffb23f6fd0d8cd34c2b25fbe6 @@ -0,0 +1,3 @@ +476 +172 +622 diff --git a/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f new file mode 100644 index 000000000000..444039e75fba --- /dev/null +++ b/sql/hive/src/test/resources/golden/Partition pruning - non-partitioned, non-trivial project - query test-0-eabbebd5c1d127b1605bfec52d7b7f3f @@ -0,0 +1,500 @@ +476 +172 +622 +54 +330 +818 +510 +556 +196 +968 +530 +386 +802 +300 +546 +448 +738 +132 +256 +426 +292 +812 +858 +748 +304 +938 +290 +990 +74 +654 +562 +554 +418 +30 +164 +806 +332 +834 +860 +504 +584 +438 +574 +306 +386 +676 +892 +918 +788 +474 +964 +348 +826 +988 +414 +398 +932 +416 +348 +798 +792 +494 +834 +978 +324 +754 +794 +618 +730 +532 +878 +684 +734 +650 +334 +390 +950 +34 +226 +310 +406 +678 +0 +910 +256 +622 +632 +114 +604 +410 +298 +876 +690 +258 +340 +40 +978 +314 +756 +442 +184 +222 +94 +144 +8 +560 +70 +854 +554 +416 +712 +798 +338 +764 +996 +250 +772 +874 +938 +384 +572 +374 +352 +108 +918 +102 +276 +206 +478 +426 +432 +860 +556 +352 +578 +442 +130 +636 +664 +622 +550 +274 +482 +166 +666 +360 +568 +24 +460 +362 +134 +520 +808 +768 +978 +706 +746 +544 +276 +434 +168 +696 +932 +116 +16 +822 +460 +416 +696 +48 +926 +862 +358 +344 +84 +258 +316 +238 +992 +0 +644 +394 +936 +786 +908 +200 +596 +398 +382 +836 +192 +52 +330 +654 +460 +410 +240 +262 +102 +808 +86 +872 +312 +938 +936 +616 +190 +392 +576 +962 +914 +196 +564 +394 +374 +636 +636 +818 +940 +274 +738 +632 +338 +826 +170 +154 +0 +980 +174 +728 +358 +236 +268 +790 +564 +276 +476 +838 +30 +236 +144 +180 +614 +38 +870 +20 +554 +546 +612 +448 +618 +778 +654 +484 +738 +784 +544 +662 +802 +484 +904 +354 +452 +10 +994 +804 +792 +634 +790 +116 +70 +672 +190 +22 +336 +68 +458 +466 +286 +944 +644 +996 +320 +390 +84 +642 +860 +238 +978 +916 +156 +152 +82 +446 +984 +298 +898 +436 +456 +276 +906 +60 +418 +128 +936 +152 +148 +684 +138 +460 +66 +736 +206 +592 +226 +432 +734 +688 +334 +548 +438 +478 +970 +232 +446 +512 +526 +140 +974 +960 +802 +576 +382 +10 +488 +876 +256 +934 +864 +404 +632 +458 +938 +926 +560 +4 +70 +566 +662 +470 +160 +88 +386 +642 +670 +208 +932 +732 +350 +806 +966 +106 +210 +514 +812 +818 +380 +812 +802 +228 +516 +180 +406 +524 +696 +848 +24 +792 +402 +434 +328 +862 +908 +956 +596 +250 +862 +328 +848 +374 +764 +10 +140 +794 +960 +582 +48 +702 +510 +208 +140 +326 +876 +238 +828 +400 +982 +474 +878 +720 +496 +958 +610 +834 +398 +888 +240 +858 +338 +886 +646 +650 +554 +460 +956 +356 +936 +620 +634 +666 +986 +920 +414 +498 +530 +960 +166 +272 +706 +344 +428 +924 +466 +812 +266 +350 +378 +908 +750 +802 +842 +814 +768 +512 +52 +268 +134 +768 +758 +36 +924 +984 +200 +596 +18 +682 +996 +292 +916 +724 +372 +570 +696 +334 +36 +546 +366 +562 +688 +194 +938 +630 +168 +56 +74 +896 +304 +696 +614 +388 +828 +954 +444 +252 +180 +338 +806 +800 +400 +194 diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties 
index 92eaf1f2795b..fea3404769d9 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -48,9 +48,14 @@ log4j.logger.hive.log=OFF log4j.additivity.parquet.hadoop.ParquetRecordReader=false log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF + log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF log4j.additivity.org.apache.hadoop.hive.ql.io.RCFile=false log4j.logger.org.apache.hadoop.hive.ql.io.RCFile=ERROR - diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 332c3ec0c28b..574624d501f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{Row, SaveMode, SQLContext} import org.apache.spark.{Logging, SparkFunSuite} @@ -53,9 +53,13 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + import testImplicits._ - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) Seq( "parquet" -> ( @@ -88,10 +92,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -117,10 +121,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278..fe0db5228de1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.{QueryTest, Row, SQLContext} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ + private val ctx = TestHive + override def _sqlContext: SQLContext = ctx test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b8d41065d3f0..dc2d85f48624 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp +import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.sys.process.{ProcessLogger, Process} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -37,6 +41,8 @@ import org.scalatest.time.SpanSugar._ class HiveSparkSubmitSuite extends SparkFunSuite with Matchers + // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we + // add a timestamp to provide more diagnosis information. 
with ResetSystemProperties with Timeouts { @@ -50,13 +56,15 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -68,6 +76,8 @@ class HiveSparkSubmitSuite "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -79,7 +89,21 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) + runSparkSubmit(args) + } + + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) runSparkSubmit(args) } @@ -90,28 +114,44 @@ class HiveSparkSubmitSuite val history = ArrayBuffer.empty[String] val commands = Seq("./bin/spark-submit") ++ args val commandLine = commands.mkString("'", "' '", "'") - val process = Process( - commands, - new File(sparkHome), - "SPARK_TESTING" -> "1", - "SPARK_HOME" -> sparkHome - ).run(ProcessLogger( + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. 
+ val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" // scalastyle:off println - (line: String) => { println(s"stdout> $line"); history += s"out> $line"}, - (line: String) => { println(s"stderr> $line"); history += s"err> $line" } + println(logLine) // scalastyle:on println - )) + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.exitValue() } + val exitCode = failAfter(180.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised Thread.sleep(1000) val historyLog = history.mkString("\n") - fail(s"$commandLine returned with exit code $exitCode." + - s" See the log4j logs for more detail." + - s"\n$historyLog") + fail { + s"""spark-submit returned with exit code $exitCode. + |Command line: $commandLine + | + |$historyLog + """.stripMargin + } } } catch { case to: TestFailedDueToTimeoutException => @@ -205,7 +245,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -231,3 +271,46 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6..d3388a9429e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") @@ -42,7 +41,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def afterAll(): Unit = { catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") @@ -55,7 +53,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), Row("hivelisttablessuitetable", false)) @@ -69,9 +66,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - checkAnswer( - allTables.filter("tableName = 'indblisttablessuitetable'"), - Row("indblisttablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b73d6665755d..20a50586d520 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,7 +22,6 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging @@ -32,7 +31,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -42,7 +41,8 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll with Logging { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20..997c667ec0d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -19,15 +19,22 @@ package org.apache.spark.sql.hive import 
org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override val _sqlContext: HiveContext = TestHive + private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) + private def checkTablePath(dbName: String, tableName: String): Unit = { + // val hiveContext = sqlContext.asInstanceOf[HiveContext] + val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) + val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + + assert(metastoreTable.serdeProperties("path") === expectedPath) + } + test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { @@ -38,6 +45,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") } } @@ -46,6 +55,58 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } } } @@ -60,6 +121,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -69,6 +132,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -131,7 +196,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { } } - test("Refreshes a table in a non-default database") { + test("Refreshes a table in a non-default database - with USE") { import org.apache.spark.sql.functions.lit withTempDatabase { db => @@ -152,8 +217,94 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { sql("ALTER 
TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + sqlContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) } } } } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + sqlContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index f00d3754c364..13452e71a1b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -20,65 +20,67 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - override val sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext /** * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. 
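The refresh and invalid-identifier tests above pin down database-qualified table handling. A standalone sketch of the happy path (the database name `demo_db`, the app name, and the path argument are illustrative; it assumes a Hive-enabled Spark 1.5 build launched via spark-submit):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object RefreshNonDefaultDbSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("refresh-sketch"))
    val hiveContext = new HiveContext(sc)
    val path = args(0) // a writable directory to back the external table

    hiveContext.sql("CREATE DATABASE IF NOT EXISTS demo_db")
    hiveContext.sql(
      s"""CREATE EXTERNAL TABLE IF NOT EXISTS demo_db.t (id BIGINT)
         |PARTITIONED BY (p INT)
         |STORED AS PARQUET
         |LOCATION '$path'
       """.stripMargin)

    // Write files under a new partition directory, register the partition, then refresh
    // the cached relation using the database-qualified name.
    hiveContext.range(10).coalesce(1).write.parquet(s"$path/p=1")
    hiveContext.sql("ALTER TABLE demo_db.t ADD PARTITION (p=1)")
    hiveContext.refreshTable("demo_db.t")

    assert(hiveContext.table("demo_db.t").count() == 10)
    sc.stop()
  }
}
```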
*/ - override val stagingDir: Option[String] = - Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - override protected def beforeAll(): Unit = { - super.beforeAll() + test("Read Parquet file generated by parquet-hive") { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath - withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { - sqlContext.sql( - s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '${parquetStore.getCanonicalPath}' - """.stripMargin) + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") - } - } - } + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } - override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") - } + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) - test("Read Parquet file generated by parquet-hive") { - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - |${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |$schema + """.stripMargin) - // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. - // Have to assume all BINARY values are strings here. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. 
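The rewritten compatibility test reads the Hive-written files back with `spark.sql.parquet.binaryAsString` enabled because parquet-hive omits the UTF8 annotation on string columns. Outside the harness the equivalent read is a two-liner (spark-shell style; the path is a placeholder):

```scala
// Assumes a spark-shell with `sqlContext` bound and a directory written by parquet-hive.
sqlContext.setConf("spark.sql.parquet.binaryAsString", "true")
val df = sqlContext.read.parquet("/path/to/parquet-written-by-parquet-hive")
df.printSchema() // BINARY columns now surface as StringType instead of BinaryType
```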
+ withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), makeRows) + } + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d..7ee1c8d13aa3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9..119663af1887 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql._ -import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ @@ -479,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(0, null, 1, 1, null, 0) :: Nil) } + test("test Last implemented based on AggregateExpression1") { + // TODO: Remove this test once we remove AggregateExpression1. + import org.apache.spark.sql.functions._ + val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + + checkAnswer( + df.groupBy("i").agg(last("j")), + df + ) + } + } + test("error handling") { withSQLConf("spark.sql.useAggregate2" -> "false") { val errorMessage = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 44c5b80392fa..11d7a872dff0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. 
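The UDF suite checks that a function registered as `random0` also resolves as `RANDOM0()`. The same check in a spark-shell session (the UDF name is arbitrary):

```scala
// Register a zero-argument UDF, then call it with a differently-cased name.
sqlContext.udf.register("random0", () => math.random)
sqlContext.sql("SELECT RANDOM0() AS r").show()
```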
*/ class HiveExplainSuite extends QueryTest with SQLTestUtils { - - def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2fa7ae3fa2e1..55ecbd5b5f21 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("UDTF") { sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") @@ -1137,4 +1138,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(CalendarInterval.fromString( "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. 
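This new test encodes the rule that `CREATE TEMPORARY TABLE` rejects a database qualifier but accepts a backtick-quoted name that merely contains a dot. A spark-shell sketch of both sides of that rule (the scratch path is a placeholder):

```scala
import org.apache.spark.sql.AnalysisException

val path = "/tmp/temp-table-demo"
sqlContext.range(10).write.mode("overwrite").parquet(path)

// A database qualifier is rejected for temporary tables...
try {
  sqlContext.sql(
    s"""CREATE TEMPORARY TABLE db.t
       |USING parquet
       |OPTIONS (path '$path')
     """.stripMargin)
} catch {
  case e: AnalysisException => println(e.getMessage) // "Specifying database name ... not allowed"
}

// ...but backticks turn `db.t` into a single table name that happens to contain a dot.
sqlContext.sql(
  s"""CREATE TEMPORARY TABLE `db.t`
     |USING parquet
     |OPTIONS (path '$path')
   """.stripMargin)
sqlContext.table("`db.t`").show()
```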
+ sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3..9aca40f15ac1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da0..f7ba20ff41d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - + protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ import sqlContext.sparkContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c4bc60086f6e..34d3434569f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { * A collection of tests for parquet data with various forms of partitioning. 
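The `_sqlContext` override that recurs throughout these hunks replaces each suite's own `val sqlContext` with a single abstract hook plus a local alias. A simplified model of the pattern — not the actual `SQLTestUtils` trait — showing why the alias has to be a `val` (importing `implicits` requires a stable identifier, which a `def` is not):

```scala
import org.apache.spark.sql.SQLContext

// Hypothetical stand-in for the mix-in: one abstract accessor shared by every suite.
trait SqlContextProvider {
  protected def _sqlContext: SQLContext
}

abstract class ExampleSuite extends SqlContextProvider {
  override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
  private val sqlContext = _sqlContext   // stable identifier for the import below
  import sqlContext.implicits._
}
```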
*/ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706..b4640b161628 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..ed6d512ab36f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d..cb4cedddbfdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,12 +23,12 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + override val dataSourceName: String = "parquet" import sqlContext._ import sqlContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 1813cc33226d..e8975e5f5cd0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -50,33 +50,3 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } } - -class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = - classOf[org.apache.spark.sql.json.DefaultSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - 
partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a69d331b6e5..5bbca14bad32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.sql.Date + import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration @@ -34,9 +36,8 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ val dataSourceName: String @@ -554,6 +555,55 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } + + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") + withTempDir { file => + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) + } + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") + } + } + + test("SPARK-9899 Disable customized output committer when speculation is on") { + val clonedConf = new Configuration(configuration) + val speculationEnabled = + sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + try { + withTempPath { dir => + // Enables task speculation + sqlContext.sparkContext.conf.set("spark.speculation", "true") + + // Uses a customized output committer which always fails + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + + // Code below shouldn't throw since customized output committer should be disabled. 
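SPARK-8887 above restricts dynamic partition columns to atomic types; partitioning by an array, map, or struct column now fails analysis instead of silently producing unreadable output. A spark-shell-style sketch of the behaviour the test asserts (paths are placeholders):

```scala
import org.apache.spark.sql.AnalysisException
import sqlContext.implicits._

val df = Seq(
  (1, "v1", Seq(1, 2, 3)),
  (2, "v2", Seq(4, 5, 6))
).toDF("a", "b", "c")

// Partitioning by the Int and String columns is allowed.
df.write.mode("overwrite").partitionBy("a", "b").parquet("/tmp/partition-demo-ok")

// Partitioning by the array column is rejected up front with an AnalysisException.
try {
  df.write.mode("overwrite").partitionBy("c").parquet("/tmp/partition-demo-bad")
} catch {
  case e: AnalysisException => println(e.getMessage)
}
```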
+ val df = sqlContext.range(10).coalesce(1) + df.write.format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2780d5b6adbc..6f6b449accc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -192,7 +192,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -203,7 +205,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)){ + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 177e710ace54..b496d1f341a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -604,7 +604,7 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) @@ -691,7 +691,7 @@ class StreamingContext private[streaming] ( StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + ShutdownHookManager.removeShutdownHook(shutdownHookRef) } logInfo("StreamingContext stopped successfully") } @@ -725,7 +725,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 808dcc174cf9..214cd80108b9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -291,7 +291,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction[R, Void]) { foreachRDD(foreachFunc) } @@ -302,7 +302,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * * @deprecated As of release 0.9.0, replaced by foreachRDD */ - @Deprecated + @deprecated("Use foreachRDD", "0.9.0") def foreach(foreachFunc: JFunction2[R, Time, Void]) { foreachRDD(foreachFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 31ce8e1ec14d..620b8a36a2ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -84,7 +84,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( require( blockIds.length == walRecordHandles.length, s"Number of block Ids (${blockIds.length}) must be " + - s" same as number of WAL record handles (${walRecordHandles.length}})") + s" same as number of WAL record handles (${walRecordHandles.length})") require( isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 794dece370b2..421d60ae359f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -155,10 +155,17 @@ private[streaming] class BlockGenerator( /** * Push a single data item into the buffer. 
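The streaming context now registers its shutdown hook through the extracted `ShutdownHookManager` helper instead of `Utils`. A sketch of the register/remove shape used above (the surrounding class and hook body are hypothetical; the helper is Spark-internal, hence the `org.apache.spark` package):

```scala
package org.apache.spark.streaming.sketch

import org.apache.spark.util.ShutdownHookManager

class HookOwner {
  // Runs on JVM shutdown; the +1 priority makes it fire before the SparkContext's own hook.
  private def stopOnShutdown(): Unit = {
    // stop receivers, flush state, etc.
  }

  private var shutdownHookRef: AnyRef =
    ShutdownHookManager.addShutdownHook(
      ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1)(stopOnShutdown)

  def stop(): Unit = {
    // Deregister on a clean stop so the hook does not run redundantly afterwards.
    if (shutdownHookRef != null) {
      ShutdownHookManager.removeShutdownHook(shutdownHookRef)
      shutdownHookRef = null
    }
  }
}
```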
*/ - def addData(data: Any): Unit = synchronized { + def addData(data: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -169,11 +176,18 @@ private[streaming] class BlockGenerator( * Push a single data item into the buffer. After buffering the data, the * `BlockGeneratorListener.onAddData` callback will be called. */ - def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -185,13 +199,23 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ - def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] dataIterator.foreach { data => waitToPush() - currentBuffer += data + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } - listener.onAddData(dataIterator, metadata) } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -203,16 +227,21 @@ private[streaming] class BlockGenerator( def isStopped(): Boolean = state == StoppedAll /** Change the buffer to which single records are added to. 
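The reworked `addData` variants wait on the rate limiter outside the lock and re-check `state` once the lock is held, so a caller blocked in `waitToPush()` can no longer stall block rotation. A stripped-down sketch of that double-check shape (simplified names, not the real `BlockGenerator`):

```scala
import scala.collection.mutable.ArrayBuffer

// `waitToPush` models the rate limiter; it may block for a long time.
class BoundedBuffer(waitToPush: () => Unit) {
  @volatile private var active = true
  private val currentBuffer = new ArrayBuffer[Any]

  def addData(data: Any): Unit = {
    if (active) {
      // Block on the rate limiter *outside* the lock so buffer rotation is never
      // held up behind a caller that is merely waiting for permits.
      waitToPush()
      synchronized {
        // Re-check under the lock: stop() may have raced with the wait above.
        if (active) {
          currentBuffer += data
        } else {
          throw new IllegalStateException("Cannot add data after the generator has stopped")
        }
      }
    } else {
      throw new IllegalStateException("Cannot add data after the generator has stopped")
    }
  }

  def stop(): Unit = synchronized { active = false }
}
```

The second state check is what makes the cheaper pre-check safe: the outer test only avoids pointless waits, while the inner one decides whether the record is actually accepted.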
*/ - private def updateCurrentBuffer(time: Long): Unit = synchronized { + private def updateCurrentBuffer(time: Long): Unit = { try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[Any] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(receiverId, time - blockIntervalMs) - val newBlock = new Block(blockId, newBlockBuffer) - listener.onGenerateBlock(blockId) + var newBlock: Block = null + synchronized { + if (currentBuffer.nonEmpty) { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[Any] + val blockId = StreamBlockId(receiverId, time - blockIntervalMs) + listener.onGenerateBlock(blockId) + newBlock = new Block(blockId, newBlockBuffer) + } + } + + if (newBlock != null) { blocksForPushing.put(newBlock) // put is blocking when queue is full - logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) } } catch { case ie: InterruptedException => @@ -226,9 +255,13 @@ private[streaming] class BlockGenerator( private def keepPushingBlocks() { logInfo("Started block pushing thread") - def isGeneratingBlocks = synchronized { state == Active || state == StoppedAddingData } + def areBlocksBeingGenerated: Boolean = synchronized { + state != StoppedGeneratingBlocks + } + try { - while (isGeneratingBlocks) { + // While blocks are being generated, keep polling for to-be-pushed blocks and push them. + while (areBlocksBeingGenerated) { Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 363c03d431f0..deb15d075975 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + s"$batchTime is already added into InputInfoTracker, this is a illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f2117ada61c..2de035d166e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -79,6 +79,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { def start(): Unit = synchronized { if (eventLoop != null) return // generator has already been started + // Call checkpointWriter here to initialize it before eventLoop uses it to avoid a deadlock. 
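`updateCurrentBuffer` above now swaps the buffer inside the lock but performs the potentially blocking `blocksForPushing.put` after releasing it. The same swap-then-push shape in isolation (simplified types):

```scala
import java.util.concurrent.ArrayBlockingQueue

import scala.collection.mutable.ArrayBuffer

class BlockRotator {
  private var currentBuffer = new ArrayBuffer[Any]
  private val blocksForPushing = new ArrayBlockingQueue[ArrayBuffer[Any]](10)

  def rotate(): Unit = {
    // Swap the buffer under the lock and keep the critical section tiny.
    val newBlock: ArrayBuffer[Any] = synchronized {
      if (currentBuffer.nonEmpty) {
        val fullBuffer = currentBuffer
        currentBuffer = new ArrayBuffer[Any]
        fullBuffer
      } else {
        null
      }
    }

    // put() blocks when the queue is full, so it must run outside the lock; otherwise
    // addData() callers would be stalled behind a slow block-pushing thread.
    if (newBlock != null) {
      blocksForPushing.put(newBlock)
    }
  }
}
```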
+ // See SPARK-10125 + checkpointWriter + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala index 882ca0676b6a..a46c0c1b25e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -76,9 +76,9 @@ private[streaming] abstract class RateController(val streamUID: Int, rateEstimat val elements = batchCompleted.batchInfo.streamIdToInputInfo for { - processingEnd <- batchCompleted.batchInfo.processingEndTime; - workDelay <- batchCompleted.batchInfo.processingDelay; - waitDelay <- batchCompleted.batchInfo.schedulingDelay; + processingEnd <- batchCompleted.batchInfo.processingEndTime + workDelay <- batchCompleted.batchInfo.processingDelay + waitDelay <- batchCompleted.batchInfo.schedulingDelay elems <- elements.get(streamUID).map(_.numRecords) } computeAndPublish(processingEnd, elems, workDelay, waitDelay) } @@ -86,5 +86,5 @@ private[streaming] abstract class RateController(val streamUID: Int, rateEstimat object RateController { def isBackPressureEnabled(conf: SparkConf): Boolean = - conf.getBoolean("spark.streaming.backpressure.enable", false) + conf.getBoolean("spark.streaming.backpressure.enabled", false) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 7720259a5d79..53b96d51c918 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. 
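The bare `checkpointWriter` statement above forces a lazy val to initialize on the starting thread before the event loop exists, closing the SPARK-10125 initialization race. The general shape of that fix as a hypothetical example (not the actual `JobGenerator`):

```scala
class Generator {
  // Lazily constructed; the first access runs under the lazy-val initialization lock.
  private lazy val checkpointWriter: java.io.File =
    java.io.File.createTempFile("ckpt", ".tmp")

  private var eventLoopThread: Thread = _

  def start(): Unit = {
    // Force the lazy val on the starting thread, before the event-loop thread is spawned,
    // so two threads can never race to initialize it while holding other locks.
    checkpointWriter

    eventLoopThread = new Thread(new Runnable {
      override def run(): Unit = println(s"writing to ${checkpointWriter.getPath}")
    })
    eventLoopThread.start()
  }
}
```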
*/ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -199,7 +199,8 @@ private[streaming] class ReceivedBlockTracker( import scala.collection.JavaConversions._ writeAheadLog.readAll().foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + Utils.deserialize[ReceivedBlockTrackerLogEvent]( + byteBuffer.array, Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index e076fb5ea174..aae3acf7aba3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -468,8 +468,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * Start a receiver along with its scheduled executors */ private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + def shouldStartReceiver: Boolean = { + // It's okay to start when trackerState is Initialized or Started + !(isTrackerStopping || isTrackerStopped) + } + val receiverId = receiver.streamId - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) return } @@ -494,14 +499,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // We will keep restarting the receiver job until ReceiverTracker is stopped future.onComplete { case Success(_) => - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) } else { logInfo(s"Restarting Receiver $receiverId") self.send(RestartReceiver(receiver)) } case Failure(e) => - if (!isTrackerStarted) { + if (!shouldStartReceiver) { onReceiverJobFinish(receiverId) } else { logError("Receiver has been stopped. Try to restart it.", e) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index 6ae56a68ad88..84a3ca9d74e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.scheduler.rate +import org.apache.spark.Logging + /** * Implements a proportional-integral-derivative (PID) controller which acts on * the speed of ingestion of elements into Spark Streaming. A PID controller works @@ -26,7 +28,7 @@ package org.apache.spark.streaming.scheduler.rate * * @see https://en.wikipedia.org/wiki/PID_controller * - * @param batchDurationMillis the batch duration, in milliseconds + * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current * error. This term usually provides the bulk of correction and should be positive or zero. * A value too large would make the controller overshoot the setpoint, while a small value @@ -39,13 +41,17 @@ package org.apache.spark.streaming.scheduler.rate * of future errors, based on current rate of change. This value should be positive or 0. 
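The estimator's update, visible above, clamps `latestRate - proportional * error - integral * historicalError - derivative * dError` to `minRate`. With the default test estimator `new PIDRateEstimator(20, 1D, 0D, 0D, 10)` the integral and derivative gains are zero, so the `Some(1000)` expected by the second-estimate test reduces to the proportional term alone. A worked sketch of that arithmetic (not the estimator source):

```scala
// Inputs of the second compute() call in "second estimate is not None".
val (proportional, minRate) = (1.0, 10.0)
val (numElements, processingDelayMs) = (10L, 10L)

// Records processed per second during the last batch: 10 elements in 10 ms -> 1000/s.
val processingRate = numElements.toDouble / processingDelayMs * 1000

// The first call only seeds latestRate with the processing rate and returns None.
val latestRate = processingRate

// The proportional error is zero because the current rate already matches the
// processing rate, so the new estimate is simply 1000 records/s.
val error = latestRate - processingRate
val newRate = math.max(latestRate - proportional * error, minRate)
assert(newRate == 1000.0)
```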
* This term is not used very often, as it impacts stability of the system. The default * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. */ private[streaming] class PIDRateEstimator( batchIntervalMillis: Long, - proportional: Double = 1D, - integral: Double = .2D, - derivative: Double = 0D) - extends RateEstimator { + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { private var firstRun: Boolean = true private var latestTime: Long = -1L @@ -64,16 +70,23 @@ private[streaming] class PIDRateEstimator( require( derivative >= 0, s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") - def compute(time: Long, // in milliseconds + def compute( + time: Long, // in milliseconds numElements: Long, processingDelay: Long, // in milliseconds schedulingDelay: Long // in milliseconds ): Option[Double] = { - + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") this.synchronized { - if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + if (time > latestTime && numElements > 0 && processingDelay > 0) { // in seconds, should be close to batchDuration val delaySinceUpdate = (time - latestTime).toDouble / 1000 @@ -104,21 +117,30 @@ private[streaming] class PIDRateEstimator( val newRate = (latestRate - proportional * error - integral * historicalError - - derivative * dError).max(0.0) + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + latestTime = time if (firstRun) { latestRate = processingRate latestError = 0D firstRun = false - + logTrace("First run, rate estimation skipped") None } else { latestRate = newRate latestError = error - + logTrace(s"New rate = $newRate") Some(newRate) } - } else None + } else { + logTrace("Rate estimation skipped") + None + } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 17ccebc1ed41..d7210f64fcc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf -import org.apache.spark.SparkException import org.apache.spark.streaming.Duration /** @@ -61,7 +60,8 @@ object RateEstimator { val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) - new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, 
integral, derived, minRate) case estimator => throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 0c4c06534a69..e82c2fa4e72a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -17,25 +17,32 @@ package org.apache.spark.streaming -import org.apache.spark.Logging +import java.io.File + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkFunSuite, Logging} import org.apache.spark.util.Utils /** * This testsuite tests master failures at random times while the stream is running using * the real clock. */ -class FailureSuite extends TestSuiteBase with Logging { +class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val directory = Utils.createTempDir() - val numBatches = 30 + private val batchDuration: Duration = Milliseconds(1000) + private val numBatches = 30 + private var directory: File = null - override def batchDuration: Duration = Milliseconds(1000) - - override def useManualClock: Boolean = false + before { + directory = Utils.createTempDir() + } - override def afterFunction() { - Utils.deleteRecursively(directory) - super.afterFunction() + after { + if (directory != null) { + Utils.deleteRecursively(directory) + } + StreamingContext.getActive().foreach { _.stop() } } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala index 97c32d8f2d59..a1af95be81c8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -36,72 +36,89 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { test("estimator checks ranges") { intercept[IllegalArgumentException] { - new PIDRateEstimator(0, 1, 2, 3) + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, -1, 2, 3) + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, -1, 3) + new PIDRateEstimator(100, 0, integral = -1, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, 0, -1) + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) } } - private def createDefaultEstimator: PIDRateEstimator = { - new PIDRateEstimator(20, 1D, 0D, 0D) - } - - test("first bound is None") { - val p = createDefaultEstimator + test("first estimate is None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) should equal(None) } - test("second bound is rate") { - val p = createDefaultEstimator + test("second estimate is not None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) // 1000 elements / s p.compute(10, 10, 10, 0) should equal(Some(1000)) } - test("works even with no time between updates") { - val p = createDefaultEstimator + test("no estimate when no time difference between successive calls") { + val p = 
createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) - p.compute(10, 10, 10, 0) - p.compute(10, 10, 10, 0) should equal(None) + p.compute(10, numElements = 0, 10, 0) should equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) } - test("bound is never negative") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing // this might point the estimator to try and decrease the bound, but we test it never - // goes below zero, which would be nonsensical. + // goes below the min rate, which would be nonsensical. val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.fill(50)(0) // no processing + val elements = List.fill(50)(1) // no processing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(100) // strictly positive accumulation val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.fill(49)(Some(0D))) + res.tail should equal(List.fill(49)(Some(minRate))) } test("with no accumulated or positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an increasing number of processed // elements in each batch, but constant processing time, and no accumulated error. Even though // the integral part is non-zero, the estimated rate should follow only the proportional term val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.tabulate(50)(x => x * 20) // increasing + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(0) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) } test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an decreasing number of processed // elements in each batch, but constant processing time, and no accumulated error. 
Even though // the integral part is non-zero, the estimated rate should follow only the proportional term, @@ -116,13 +133,14 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { } test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { - val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) val times = List.tabulate(50)(x => x * 20) // every 20ms val rng = new Random() - val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) val procDelayMs = 20 val proc = List.fill(50)(procDelayMs) // 20ms of processing - val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) @@ -131,7 +149,12 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { res(n) should not be None if (res(n).get > 0 && sched(n) > 0) { res(n).get should be < speeds(n) + res(n).get should be >= minRate } } } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java new file mode 100644 index 000000000000..1c16da982923 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. 
+ */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java deleted file mode 100644 index b2de2a2590f0..000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe; - -import java.lang.reflect.Field; -import java.math.BigInteger; - -import sun.misc.Unsafe; - -public final class PlatformDependent { - - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. 
- */ - public static final class UNSAFE { - - private UNSAFE() { } - - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } - - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } - - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } - - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } - - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } - - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } - - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } - - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } - - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } - - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } - - public static float getFloat(Object object, long offset) { - return _UNSAFE.getFloat(object, offset); - } - - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } - - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } - - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } - - public static Object getObjectVolatile(Object object, long offset) { - return _UNSAFE.getObjectVolatile(object, offset); - } - - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); - } - - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } - - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } - - } - - private static final Unsafe _UNSAFE; - - public static final int BYTE_ARRAY_OFFSET; - - public static final int INT_ARRAY_OFFSET; - - public static final int LONG_ARRAY_OFFSET; - - public static final int DOUBLE_ARRAY_OFFSET; - - // Support for resetting final fields while deserializing - public static final long BIG_INTEGER_SIGNUM_OFFSET; - public static final long BIG_INTEGER_MAG_OFFSET; - - /** - * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to - * allow safepoint polling during a large copy. 
- */ - private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; - - static { - sun.misc.Unsafe unsafe; - try { - Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); - unsafeField.setAccessible(true); - unsafe = (sun.misc.Unsafe) unsafeField.get(null); - } catch (Throwable cause) { - unsafe = null; - } - _UNSAFE = unsafe; - - if (_UNSAFE != null) { - BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); - INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); - LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); - DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); - - long signumOffset = 0; - long magOffset = 0; - try { - signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum")); - magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag")); - } catch (Exception ex) { - // should not happen - } - BIG_INTEGER_SIGNUM_OFFSET = signumOffset; - BIG_INTEGER_MAG_OFFSET = magOffset; - } else { - BYTE_ARRAY_OFFSET = 0; - INT_ARRAY_OFFSET = 0; - LONG_ARRAY_OFFSET = 0; - DOUBLE_ARRAY_OFFSET = 0; - BIG_INTEGER_SIGNUM_OFFSET = 0; - BIG_INTEGER_MAG_OFFSET = 0; - } - } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 70b81ce015dd..cf42877bf9fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import static org.apache.spark.unsafe.PlatformDependent.*; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -45,20 +45,18 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, - long leftOffset, - Object rightBase, - long rightOffset, - final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { - if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { return false; } i += 8; } while (i < length) { - if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { return false; } i += 1; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb..74105050e419 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import 
org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -64,7 +64,7 @@ public long size() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +73,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e6..7857bf66a72a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Methods for working with fixed-size uncompressed bitsets. @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); } /** @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); } /** @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + final long word = Platform.getLong(baseObject, wordOffset); return (word & mask) != 0; } @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { - if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { + if (Platform.getLong(baseObject, addr) != 0) { return true; } } @@ -109,8 +109,7 @@ public static int nextSetBit( // Try to find the next set bit in the current word final int subIndex = fromIndex & 0x3f; - long word = - PlatformDependent.UNSAFE.getLong(baseObject, 
baseOffset + wi * WORD_SIZE) >> subIndex; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; if (word != 0) { return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); } @@ -118,7 +117,7 @@ public static int nextSetBit( // Find the next set bit in the rest of the words wi += 1; while (wi < bitsetSizeInWords) { - word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); if (word != 0) { return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 61f483ced321..4276f25c2165 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,7 +53,7 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index 91be46ba21ff..dd7582083437 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -19,7 +19,7 @@ import javax.annotation.Nullable; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. @@ -50,6 +50,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. */ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 358bb3725015..97b2c93f0dc3 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -60,7 +60,7 @@ public class TaskMemoryManager { /** * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is - * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page * size is limited by the maximum amount of data that can be stored in a long[] array, which is * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. 
*/ @@ -144,14 +144,16 @@ public MemoryBlock allocatePage(long size) { public void freePage(MemoryBlock page) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - executorMemoryManager.free(page); + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { allocatedPages.clear(page.pageNumber); } - pageTable[page.pageNumber] = null; if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } + // Cannot access a page once it's freed. + executorMemoryManager.free(page); } /** @@ -166,7 +168,9 @@ public void freePage(MemoryBlock page) { public MemoryBlock allocate(long size) throws OutOfMemoryError { assert(size > 0) : "Size must be positive, but got " + size; final MemoryBlock memory = executorMemoryManager.allocate(size); - allocatedNonPageMemory.add(memory); + synchronized(allocatedNonPageMemory) { + allocatedNonPageMemory.add(memory); + } return memory; } @@ -176,8 +180,10 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; executorMemoryManager.free(memory); - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + synchronized(allocatedNonPageMemory) { + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } } /** @@ -223,9 +229,10 @@ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final Object page = pageTable[pageNumber].getBaseObject(); + final MemoryBlock page = pageTable[pageNumber]; assert (page != null); - return page; + assert (page.getBaseObject() != null); + return page.getBaseObject(); } else { return null; } @@ -244,7 +251,9 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { // converted the absolute address into a relative address. Here, we invert that operation: final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - return pageTable[pageNumber].getBaseOffset() + offsetInPage; + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; } } @@ -260,14 +269,17 @@ public long cleanUpAllAllocatedMemory() { freePage(page); } } - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); + + synchronized (allocatedNonPageMemory) { + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. 
+ executorMemoryManager.free(memory); + iter.remove(); + } } return freedBytes; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 62f4459696c2..cda7826c8c99 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.memory; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. @@ -29,7 +29,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } - long address = PlatformDependent.UNSAFE.allocateMemory(size); + long address = Platform.allocateMemory(size); return new MemoryBlock(null, address, size); } @@ -37,6 +37,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - PlatformDependent.UNSAFE.freeMemory(memory.offset); + Platform.freeMemory(memory.offset); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 69b0e206cef1..c08c9c73d239 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class ByteArray { @@ -27,12 +27,6 @@ public class ByteArray { * hold all the bytes in this string. */ public static void writeToMemory(byte[] src, Object target, long targetOffset) { - PlatformDependent.copyMemory( - src, - PlatformDependent.BYTE_ARRAY_OFFSET, - target, - targetOffset, - src.length - ); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d1014426c0f4..cbcab958c05a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -18,16 +18,15 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; -import java.io.Serializable; -import java.io.UnsupportedEncodingException; +import java.io.*; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import static org.apache.spark.unsafe.PlatformDependent.*; +import static org.apache.spark.unsafe.Platform.*; /** @@ -38,12 +37,13 @@ *

    * Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Serializable { +public final class UTF8String implements Comparable, Externalizable { + // These are only updated by readExternal() @Nonnull - private final Object base; - private final long offset; - private final int numBytes; + private Object base; + private long offset; + private int numBytes; public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -127,19 +127,18 @@ protected UTF8String(Object base, long offset, int numBytes) { this.numBytes = numBytes; } + // for serialization + public UTF8String() { + this(null, 0, 0); + } + /** * Writes the content of this string into a memory address, identified by an object and an offset. * The target memory address must already been allocated, and have enough space to hold all the * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - base, - offset, - target, - targetOffset, - numBytes - ); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } /** @@ -183,12 +182,12 @@ public long getPrefix() { long mask = 0; if (isLittleEndian) { if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -197,12 +196,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -293,7 +292,7 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return UNSAFE.getByte(base, offset + i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { @@ -769,7 +768,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { int len = inputs[i].numBytes; copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -778,7 +777,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... 
inputs) { if (j < numInputs) { copyMemory( separator.base, separator.offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; } @@ -984,4 +983,18 @@ public UTF8String soundex() { } return UTF8String.fromBytes(sx); } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); + } + } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229..2f8cb132ac8b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,7 +22,7 @@ import java.util.Set; import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.junit.Test; /** @@ -83,11 +83,11 @@ public void randomizedStressTestBytes() { rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. @@ -106,11 +106,11 @@ public void randomizedStressTestPaddedStrings() { System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. 
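The PIDRateEstimatorSuite changes above pin down the estimator's contract: the first valid batch only primes the internal state, batches with no records or no processing delay yield no estimate, and any estimate that is produced is clamped to the new minRate constructor argument. A minimal sketch of that contract, using only the constructor and compute() signature shown in the tests; the package path in the import is assumed from the streaming module layout, and the snippet would have to live under org.apache.spark.streaming since the class is not public.

    import org.apache.spark.streaming.scheduler.rate.PIDRateEstimator

    // constructor arguments: (batchIntervalMillis, proportional, integral, derivative, minRate)
    val estimator = new PIDRateEstimator(20, 1D, 1D, 0D, 5D)

    assert(estimator.compute(0, 10, 10, 0).isEmpty)   // first batch only primes the estimator
    assert(estimator.compute(20, 0, 10, 0).isEmpty)   // no records in the previous batch
    assert(estimator.compute(40, 10, 0, 0).isEmpty)   // no processing delay
    val rate = estimator.compute(60, 10, 20, 100)     // a regular batch can produce an estimate...
    assert(rate.forall(_ >= 5D))                      // ...and it is never below minRate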
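The new Platform class above flattens the old nested PlatformDependent.UNSAFE facade into plain static methods: on-heap arrays are addressed through the *_ARRAY_OFFSET constants, off-heap memory goes through allocateMemory/freeMemory, and copyMemory splits large copies into UNSAFE_COPY_THRESHOLD-sized chunks so the JVM can reach a safepoint. A small usage sketch against the methods visible in the new file:

    import org.apache.spark.unsafe.Platform

    // write and read a long at the start of an on-heap byte array
    val buf = new Array[Byte](16)
    Platform.putLong(buf, Platform.BYTE_ARRAY_OFFSET, 42L)
    assert(Platform.getLong(buf, Platform.BYTE_ARRAY_OFFSET) == 42L)

    // copy the first eight bytes into another on-heap array; large copies
    // are chunked internally at 1 MB to allow safepoint polling
    val dst = new Array[Byte](8)
    Platform.copyMemory(buf, Platform.BYTE_ARRAY_OFFSET, dst, Platform.BYTE_ARRAY_OFFSET, 8)

    // off-heap access goes through the same facade, with a null base object
    val address = Platform.allocateMemory(8)
    Platform.putLong(null, address, 7L)
    assert(Platform.getLong(null, address) == 7L)
    Platform.freeMemory(address)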
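The UTF8String diff above switches the class from Serializable to Externalizable: writeExternal writes the byte length followed by the raw UTF-8 bytes, and readExternal rebuilds the string onto a fresh byte[] base, which is why the fields lose their final modifier and a public no-arg constructor is added. A minimal Java-serialization round trip, assuming the unsafe artifact is on the classpath and using the existing fromString factory:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import org.apache.spark.unsafe.types.UTF8String

    val original = UTF8String.fromString("hello, wörld")

    val bytesOut = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytesOut)
    out.writeObject(original)            // invokes UTF8String.writeExternal: length + raw bytes
    out.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
    val copy = in.readObject().asInstanceOf[UTF8String]   // no-arg constructor, then readExternal
    in.close()

    assert(copy.toString == "hello, wörld")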
diff --git a/yarn/pom.xml b/yarn/pom.xml index 49360c48256e..f6737695307a 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -38,6 +38,12 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-network-yarn_${scala.binary.version} + ${project.version} + test + org.apache.spark spark-core_${scala.binary.version} @@ -92,12 +98,28 @@ jetty-servlet - + + + + org.eclipse.jetty.orbit + javax.servlet.jsp + 2.2.0.v201112011158 + test + + + org.eclipse.jetty.orbit + javax.servlet.jsp.jstl + 1.2.0.v201105211821 + test + + - + org.apache.hadoop hadoop-yarn-server-tests @@ -137,7 +159,7 @@ test - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1d67b3ebb51b..991b5cec00bd 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,8 +30,8 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, + SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} @@ -64,7 +64,8 @@ private[spark] class ApplicationMaster( // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -111,7 +112,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. 
- Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -198,7 +200,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code @@ -493,7 +495,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => @@ -529,6 +530,10 @@ private[spark] class ApplicationMaster( e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class + case SparkUserAppException(exitCode) => + val msg = s"User application exited with status $exitCode" + logError(msg) + finish(FinalApplicationStatus.FAILED, exitCode, msg) case cause: Throwable => logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 37f793763367..b08412414aa1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b4ba3f022160..c5877b6fc0d8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -163,6 +163,23 @@ private[spark] class Client( appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") + sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS) + .map(StringUtils.getTrimmedStringCollection(_)) + .filter(!_.isEmpty()) + .foreach { tagCollection => + try { + // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use + // reflection to set it, printing a warning if a tag was specified but the YARN version + // doesn't support it. 
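+ // (On Hadoop 2.4+, this reflective call is equivalent to invoking + // appContext.setApplicationTags(new java.util.HashSet[String](tagCollection)) directly.)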
+ val method = appContext.getClass().getMethod( + "setApplicationTags", classOf[java.util.Set[String]]) + method.invoke(appContext, new java.util.HashSet[String](tagCollection)) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " + + "YARN does not support it") + } + } sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match { case Some(v) => appContext.setMaxAppAttempts(v) case None => logDebug("spark.yarn.maxAppAttempts is not set. " + @@ -268,8 +285,8 @@ private[spark] class Client( // multiple times, YARN will fail to launch containers for the app with an internal // error. val distributedUris = new HashSet[String] - obtainTokenForHiveMetastore(hadoopConf, credentials) - obtainTokenForHBase(hadoopConf, credentials) + obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) + obtainTokenForHBase(sparkConf, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -414,7 +431,7 @@ private[spark] class Client( } // Distribute an archive with Hadoop and Spark configuration for the AM. - val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), resType = LocalResourceType.ARCHIVE, destName = Some(LOCALIZED_CONF_DIR), appMasterOnly = true) @@ -751,7 +768,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -951,7 +967,7 @@ object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { logWarning("WARNING: This client is deprecated and will be removed in a " + - "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") + "future version of Spark. Use ./bin/spark-submit with \"--master yarn --deploy-mode cluster (or client) OR --master yarn-cluster (yarn-client)\"") } // Set an env variable indicating we are running in YARN mode. @@ -960,6 +976,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } @@ -984,6 +1004,10 @@ object Client extends Logging { // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" + // Comma-separated list of strings to pass through as YARN application tags appearing + // in YARN ApplicationReports, which can be used for filtering when querying YARN. + val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags" + // Staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) @@ -1074,20 +1098,10 @@ object Client extends Logging { triedDefault.toOption } - /** - * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. - * So we need to use reflection to retrieve it. 
- */ private[yarn] def getDefaultMRApplicationClasspath: Option[Seq[String]] = { val triedDefault = Try[Seq[String]] { val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") - val value = if (field.getType == classOf[String]) { - StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray - } else { - field.get(null).asInstanceOf[Array[String]] - } - value.toSeq + StringUtils.getStrings(field.get(null).asInstanceOf[String]).toSeq } recoverWith { case e: NoSuchFieldException => Success(Seq.empty[String]) } @@ -1225,8 +1239,11 @@ object Client extends Logging { /** * Obtains token for the Hive metastore and adds them to the credentials. */ - private def obtainTokenForHiveMetastore(conf: Configuration, credentials: Credentials) { - if (UserGroupInformation.isSecurityEnabled) { + private def obtainTokenForHiveMetastore( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials) { + if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { @@ -1283,8 +1300,11 @@ object Client extends Logging { /** * Obtain security token for HBase. */ - def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled) { + def obtainTokenForHBase( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials): Unit = { + if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { @@ -1380,4 +1400,13 @@ object Client extends Logging { components.mkString(Path.SEPARATOR) } + /** + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ + def shouldGetTokens(conf: SparkConf, service: String): Boolean = { + conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 20d63d40cf60..4f42ffefa77f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -53,8 +53,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -196,11 +195,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. 
Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 52580deb372c..4cc50483a17f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -217,7 +217,7 @@ class ExecutorRunnable( // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? - "-XX:OnOutOfMemoryError='kill %p'") ++ + YarnSparkHadoopUtil.getOutOfMemoryErrorArgument) ++ javaOpts ++ Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", "--driver-url", masterAddress.toString, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 59caa787b6e2..ccf753e69f4b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,6 +21,8 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern +import org.apache.spark.util.Utils + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -86,7 +88,12 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } // Keep track of which container is running which executor to remove the executors later // Visible for testing. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 68d01c17ef72..445d3dcd266d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -37,6 +37,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, P import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils @@ -219,26 +220,61 @@ object YarnSparkHadoopUtil { } } + /** + * The handler if an OOM Exception is thrown by the JVM must be configured on Windows + * differently: the 'taskkill' command should be used, whereas Unix-based systems use 'kill'. + * + * As the JVM interprets both %p and %%p as the same, we can use either of them. However, + * some tests on Windows computers suggest, that the JVM only accepts '%%p'. 
+ * + * Furthermore, the behavior of the character '%' on the Windows command line differs from + * the behavior of '%' in a .cmd file: it gets interpreted as an incomplete environment + * variable. Windows .cmd files escape a '%' by '%%'. Thus, the correct way of writing + * '%%p' in an escaped way is '%%%%p'. + * + * @return The correct OOM Error handler JVM option, platform dependent. + */ + def getOutOfMemoryErrorArgument : String = { + if (Utils.isWindows) { + escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") + } else { + "-XX:OnOutOfMemoryError='kill %p'" + } + } + /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands - * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The - * argument is enclosed in single quotes and some key characters are escaped. + * using either + * + * (Unix-based) `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. + * The argument is enclosed in single quotes and some key characters are escaped. + * + * (Windows-based) part of a .cmd file in which case windows escaping for each argument must be + * applied. Windows is quite lenient, however it is usually Java that causes trouble, needing to + * distinguish between arguments starting with '-' and class names. If arguments are surrounded + * by ' java takes the following string as is, hence an argument is mistakenly taken as a class + * name which happens to start with a '-'. The way to avoid this, is to surround nothing with + * a ', but instead with a ". * * @param arg A single argument. * @return Argument quoted for execution via Yarn's generated shell script. */ def escapeForShell(arg: String): String = { if (arg != null) { - val escaped = new StringBuilder("'") - for (i <- 0 to arg.length() - 1) { - arg.charAt(i) match { - case '$' => escaped.append("\\$") - case '"' => escaped.append("\\\"") - case '\'' => escaped.append("'\\''") - case c => escaped.append(c) + if (Utils.isWindows) { + YarnCommandBuilderUtils.quoteForBatchScript(arg) + } else { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } } + escaped.append("'").toString() } - escaped.append("'").toString() } else { arg } diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala new file mode 100644 index 000000000000..3ac36ef0a1c3 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +/** + * Exposes needed methods + */ +private[spark] object YarnCommandBuilderUtils { + def quoteForBatchScript(arg: String) : String = { + CommandBuilderUtils.quoteForBatchScript(arg) + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d225061fcd1b..d06d95140438 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -81,8 +81,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -92,7 +90,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 000000000000..128e996b71fe --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. + protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ + + + def yarnConfig: YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConfig) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). 
+ val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. + props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. 
+ */ + protected def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + protected def checkResult(result: File, expected: String): Unit = { + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 837f8d3fa55a..0a5402c89e76 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -29,8 +29,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.Records import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} @@ -170,6 +173,39 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { cp should contain ("/remotePath/my1.jar") } + test("configuration and args propagate through createApplicationSubmissionContext") { + val conf = new Configuration() + // When parsing tags, duplicates and leading/trailing whitespace should be removed. + // Spaces between non-comma strings should be preserved as single tags. Empty strings may or + // may not be removed depending on the version of Hadoop being used. 
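
Aside: the tag assertions further down in this test rely on normalization along the lines of the following sketch (split on commas, trim, drop empties, de-duplicate). `parseTags` is illustrative only and is not the exact code path `Client` uses; the test body continues right after this note.

```scala
object TagParsingExample {
  // Split a comma-separated tag string, trim whitespace, drop empty entries, de-duplicate.
  def parseTags(raw: String): Set[String] =
    raw.split(",").map(_.trim).filter(_.nonEmpty).toSet

  def main(args: Array[String]): Unit = {
    // Same input string the test feeds through spark.yarn.tags.
    val tags = parseTags(",tag1, dup,tag2 , ,multi word , dup")
    assert(tags == Set("tag1", "dup", "tag2", "multi word"))
    println(tags.mkString(", "))
  }
}
```
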
+ val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup") + .set("spark.yarn.maxAppAttempts", "42") + val args = new ClientArguments(Array( + "--name", "foo-test-app", + "--queue", "staging-queue"), sparkConf) + + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) + val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext]) + + val client = new Client(args, conf, sparkConf) + client.createApplicationSubmissionContext( + new YarnClientApplication(getNewApplicationResponse, appContext), + containerLaunchContext) + + appContext.getApplicationName should be ("foo-test-app") + appContext.getQueue should be ("staging-queue") + appContext.getAMContainerSpec should be (containerLaunchContext) + appContext.getApplicationType should be ("SPARK") + appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => + val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] + tags should contain allOf ("tag1", "dup", "tag2", "multi word") + tags.filter(!_.isEmpty).size should be (4) + } + appContext.getMaxAppAttempts should be (42) + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 58318bf9bcc0..5d05f514adde 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -87,16 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index eb6e1fd37062..128350b64899 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,25 +17,20 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.File import java.net.URL -import java.util.Properties -import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.JavaConversions._ import com.google.common.base.Charsets.UTF_8 -import com.google.common.io.ByteStreams -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers import org.apache.spark._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo -import 
org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, - SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -43,17 +38,9 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { - - // log4j configuration for the YARN containers, so that their output is collected - // by YARN instead of trying to overwrite unit-tests.log. - private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin +class YarnClusterSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -82,65 +69,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | return 42 """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ - private var logConfDir: File = _ - - override def beforeAll() { - super.beforeAll() - - tempDir = Utils.createTempDir() - logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). 
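
Aside: the deletions in this hunk remove the setup that now lives in `BaseYarnClusterSuite`; the refactor is a template-method style base suite whose subclasses only supply their `YarnConfiguration`. A minimal sketch of that shape, using hypothetical names and a `Map` standing in for `YarnConfiguration`:

```scala
// Illustrative only: mirrors the shape of BaseYarnClusterSuite and its two subclasses.
abstract class BaseMiniClusterSuite {
  // Each concrete suite supplies its own cluster configuration.
  def clusterConfig: Map[String, String]

  def beforeAllSetup(): Unit = {
    // Common setup (mini cluster start, temp dirs, ...) would consume clusterConfig here.
    println(s"starting cluster with ${clusterConfig.size} extra config entries")
  }
}

class PlainClusterSuite extends BaseMiniClusterSuite {
  override def clusterConfig: Map[String, String] = Map.empty
}

class ShuffleClusterSuite extends BaseMiniClusterSuite {
  override def clusterConfig: Map[String, String] =
    Map("yarn.nodemanager.aux-services" -> "spark_shuffle")
}

object BaseSuitePatternDemo {
  def main(args: Array[String]): Unit = {
    new PlainClusterSuite().beforeAllSetup()
    new ShuffleClusterSuite().beforeAllSetup()
  }
}
```
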
- val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } - - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() - } - test("run Spark in yarn-client mode") { testBasicYarnApp(true) } @@ -174,7 +102,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } private def testBasicYarnApp(clientMode: Boolean): Unit = { - var result = File.createTempFile("result", null, tempDir) + val result = File.createTempFile("result", null, tempDir) runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) checkResult(result) @@ -224,89 +152,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(executorResult, "OVERRIDDEN") } - private def runSpark( - clientMode: Boolean, - klass: String, - appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, - extraClassPath: Seq[String] = Nil, - extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() - - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) - - // SPARK-4267: make sure java options are propagated correctly. - props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") - props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - - yarnCluster.getConfig().foreach { e => - props.setProperty("spark.hadoop." 
+ e.getKey(), e.getValue()) - } - - sys.props.foreach { case (k, v) => - if (k.startsWith("spark.")) { - props.setProperty(k, v) - } - } - - extraConf.foreach { case (k, v) => props.setProperty(k, v) } - - val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) - props.store(writer, "Spark properties.") - writer.close() - - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - private def checkResult(result: File, expected: String): Unit = { - var resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - private def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") - } - } private[spark] class SaveExecutorInfo extends SparkListener { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 000000000000..5e8238822b90 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. 
+ collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 000000000000..aa46ec5100f0 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import com.google.common.annotations.VisibleForTesting +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file mode 
100644 index 000000000000..2f22cbdbeac3 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + 
blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new 
YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (eg., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala new file mode 100644 index 000000000000..db322cd18e15 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.yarn + +import java.io.File + +/** + * just a cheat to get package-visible members in tests + */ +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } + + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } + + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } + +}
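
Both `ShuffleTestAccessor` and `YarnTestAccessor` follow the same pattern: a small object compiled into the production package that re-exports package-private state so tests in other packages can assert on it without widening visibility. A minimal sketch with hypothetical names (`WidgetService`, `WidgetTestAccessor`):

```scala
package example.service

// Production class with state that should stay hidden from normal callers.
class WidgetService {
  private[service] var boundPort: Int = -1
  def start(): Unit = { boundPort = 4242 }
}

// Test-only accessor living in the same package, mirroring YarnTestAccessor above.
object WidgetTestAccessor {
  def getBoundPort(s: WidgetService): Int = s.boundPort
}
```

A test in another package can then write `WidgetTestAccessor.getBoundPort(service) should be (4242)` while `boundPort` itself stays package-private.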