diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a00238b41d60..c3c1893487334 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -271,7 +271,7 @@ setMethod("names<-", signature(x = "DataFrame"), function(x, value) { if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) dataFrame(sdf) } }) @@ -843,10 +843,10 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) @@ -1079,7 +1079,7 @@ setMethod("subset", signature(x = "DataFrame"), #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "select", col, list(...)) dataFrame(sdf) }) @@ -1090,7 +1090,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) @@ -1106,7 +1106,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -1133,7 +1133,7 @@ setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -1311,12 +1311,12 @@ setMethod("arrange", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col, ...) { if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) + sdf <- callJMethod(x@sdf, "sort", col, list(...)) } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "sort", jcols) } dataFrame(sdf) }) @@ -1664,7 +1664,7 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) 
- sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1674,7 +1674,7 @@ setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1731,7 +1731,7 @@ setMethod("dropna", naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) @@ -1815,7 +1815,7 @@ setMethod("fillna", sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bc6445311473..4ac057d0f2d83 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -49,7 +49,7 @@ infer_type <- function(x) { stopifnot(length(x) > 0) names <- names(x) if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { # StructType types <- lapply(x, infer_type) @@ -59,7 +59,7 @@ infer_type <- function(x) { do.call(structType, fields) } } else if (length(x) > 1) { - list(type = "array", elementType = type, containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { type } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 4805096f3f9c5..42e9d12179db7 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -211,8 +211,7 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - table <- listToSeq(as.list(table)) - jc <- callJMethod(x@jc, "in", table) + jc <- callJMethod(x@jc, "in", as.list(table)) return(column(jc)) }) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 6cf628e3007de..d1858ec227b56 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -57,8 +57,10 @@ readTypedObject <- function(con, type) { readString <- function(con) { stringLen <- readInt(con) - string <- readBin(con, raw(), stringLen, endian = "big") - rawToChar(string) + raw <- readBin(con, raw(), stringLen, endian = "big") + string <- rawToChar(raw) + Encoding(string) <- "UTF-8" + string } readInt <- function(con) { diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d848730e70433..94687edb05442 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1331,7 +1331,7 @@ setMethod("countDistinct", x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) + jcol) column(jc) }) @@ -1348,7 +1348,7 @@ setMethod("concat", signature(x = "Column"), function(x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1366,7 +1366,7 @@ setMethod("greatest", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1384,7 +1384,7 @@ setMethod("least", function(x, ...) 
{ stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @export setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jcols <- lapply(list(x, ...), function(x) { x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) column(jc) }) @@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"), #' @export setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jcols <- lapply(list(x, ...), function(arg) { arg@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "format_string", format, jcols) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 576ac72f40fc0..4cab1a69f601a 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -102,7 +102,7 @@ setMethod("agg", } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } @@ -124,7 +124,7 @@ createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 79c744ef29c23..62d4f73878d29 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) 
{ }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -114,6 +114,35 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + primtiveTypes <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + if (type %in% primtiveTypes) { + return() + } else { + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index e3676f57f907f..91e6b3e5609b5 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -79,7 +79,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big") + writeBin(utfVal, con, endian = "big", useBytes=TRUE) } writeInt <- function(con, value) { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3babcb519378e..69a2bc728f842 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -361,16 +361,6 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a # user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. 
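Note — a minimal illustration, not part of the patch: the SparkR changes above (DataFrame.R, SQLContext.R, schema.R, serialize.R, utils.R) follow one idea. R lists are now handed to callJMethod/callJStatic unchanged, and the JVM backend converts the deserialized Java array to a Scala Seq while matching method signatures (see the RBackendHandler changes later in this diff), so the toSeq/listToSeq helpers can be deleted; array types are likewise expressed as plain strings such as "array<integer>" and validated recursively by the new checkType. A sketch of the resulting R-side behaviour, with made-up field names:

  infer_type(c(1L, 2L))                               # now returns the string "array<integer>"
  structField("ids", "array<integer>", TRUE)          # accepted by the new recursive checkType
  structField("tags", "array<array<string>>", TRUE)   # nested element types also pass the R-side check
  structField("bad", "array<foo>", TRUE)              # error: "Unsupported type for Dataframe: foo"

On the calling side, a method such as selectExpr now passes its R list directly, e.g. callJMethod(x@sdf, "selectExpr", exprList), instead of wrapping it with listToSeq.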
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0da5e38654732..98d4402d368e1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -49,6 +49,14 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesNa, jsonPathNa) +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -56,10 +64,8 @@ test_that("infer types", { expect_equal(infer_type(TRUE), "boolean") expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") testStruct <- infer_type(list(a = 1L, b = "2")) expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) @@ -236,8 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and struct", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -247,7 +252,32 @@ test_that("create DataFrame with different data types", { # expect_equal(count(df), 1) # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) -#}) + + + # ArrayType only for now + l <- list(as.list(1:10), list("a", "b")) + df <- createDataFrame(sqlContext, list(l), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) +}) + +test_that("Collect DataFrame with complex types", { + # only ArrayType now + # TODO: tests for StructType and MapType after they are supported + df <- jsonFile(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) +}) test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) @@ -431,6 +461,32 @@ test_that("collect() and take() on a DataFrame return the same number of rows an expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) +test_that("collect() support Unicode characters", { + markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" 
+ s + } + + lines <- c("{\"name\":\"안녕하세요\"}", + "{\"name\":\"您好\", \"age\":30}", + "{\"name\":\"こんにちは\", \"age\":19}", + "{\"name\":\"Xin chào\"}") + + jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath) + + df <- read.df(sqlContext, jsonPath, "json") + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(rdf$name[1], markUtf8("안녕하세요")) + expect_equal(rdf$name[2], markUtf8("您好")) + expect_equal(rdf$name[3], markUtf8("こんにちは")) + expect_equal(rdf$name[4], markUtf8("Xin chào")) + + df1 <- createDataFrame(sqlContext, rdf) + expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) +}) + test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { @@ -1091,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") diff --git a/README.md b/README.md index 380422ca00dbe..76e29b4235666 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, and Python, and an optimized engine that +high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, @@ -94,5 +94,5 @@ distribution. ## Configuration -Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. 
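For reference on the describe() expectation change earlier in this diff (test_sparkSQL.R): the test JSON has two non-null ages, 19 and 30, with mean 24.5. The old expected value 5.5 is the population standard deviation, while the new value is the sample (n-1) standard deviation that describe() now reports. A quick check in R, assuming those two values (not part of the patch):

  sqrt(mean((c(19, 30) - 24.5)^2))   # 5.5                 (population standard deviation, old expectation)
  sd(c(19, 30))                      # 7.7781745930520225  (sample standard deviation, new expectation)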
diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index ef0bb2ac13f08..8399033ac61ec 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -78,7 +79,7 @@ object Bagel extends Logging { val startTime = System.currentTimeMillis val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( + val combinedMsgs = msgs.combineByKeyWithClassTag( combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) val grouped = combinedMsgs.groupWith(verts) val superstep_ = superstep // Create a read-only copy of superstep for capture in closure @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/core/pom.xml b/core/pom.xml index 4f79d71bf85fa..a46292c13bcc0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -272,6 +272,14 @@ org.apache.hadoop hadoop-client + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework + org.apache.curator curator-recipes diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 3d1ef0c48adc5..e73ba39468828 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -122,6 +122,10 @@ public UnsafeShuffleExternalSorter( this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); + + // preserve first page to ensure that we have at least one page to work with. Otherwise, + // other operators in the same task may starve this sorter (SPARK-9709). 
+ acquireNewPageIfNecessary(pageSizeBytes); } /** diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index c39c8667d013e..5592b75afb75b 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils * @tparam T partial data that can be added in */ class Accumulable[R, T] private[spark] ( - @transient initialValue: R, + initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fc8cdde9348ee..9aafc9eb1cde7 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -65,8 +67,8 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) */ @DeveloperApi -class ShuffleDependency[K, V, C]( - @transient _rdd: RDD[_ <: Product2[K, V]], +class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, @@ -76,6 +78,13 @@ class ShuffleDependency[K, V, C]( override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] + private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName + private[spark] val valueClassName: String = reflect.classTag[V].runtimeClass.getName + // Note: It's possible that the combiner class tag is null, if the combineByKey + // methods in PairRDDFunctions are used instead of combineByKeyWithClassTag. + private[spark] val combinerClassName: Option[String] = + Option(reflect.classTag[C]).map(_.runtimeClass.getName) + val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 29e581bb57cbc..e4df7af81a6d2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -104,8 +104,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. 
*/ class RangePartitioner[K : Ordering : ClassTag, V]( - @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K, V]], + partitions: Int, + rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 673ef49e7c1c5..746d2081d4393 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -310,7 +310,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } - def getViewAcls: String = viewAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getViewAcls: String = { + if (viewAcls.contains("*")) { + "*" + } else { + viewAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -321,7 +330,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } - def getModifyAcls: String = modifyAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getModifyAcls: String = { + if (modifyAcls.contains("*")) { + "*" + } else { + modifyAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -394,7 +412,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) + !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") } /** @@ -409,7 +427,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) + !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 738887076b0d1..e27b3c4962221 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -858,7 +858,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -910,7 +910,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1092,7 +1092,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1516,8 +1516,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + getRDDStorageInfo(_ => true) + } + + private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0f1e2e069568d..c6fef7f91f00c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -33,7 +33,6 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} @@ -326,15 +325,7 @@ object SparkEnv extends Logging { val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) - val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { - case "netty" => - new NettyBlockTransferService(conf, securityManager, numUsableCores) - case "nio" => - logWarning("NIO-based block transfer service is deprecated, " + - "and will be removed in Spark 1.6.0.") - new NioBlockTransferService(conf, securityManager) - } + val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f5dd36cbcfe6d..ae5926dd534a6 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
*/ private[spark] -class SparkHadoopWriter(@transient jobConf: JobConf) +class SparkHadoopWriter(jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 934d00dc708b9..2ae878b3e6087 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -48,6 +48,8 @@ case object Success extends TaskEndReason sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. */ def toErrorString: String + + def shouldEventuallyFailJob: Boolean = true } /** @@ -194,6 +196,12 @@ case object TaskKilled extends TaskFailedReason { case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + /** + * If a task failed because its attempt to commit was denied, do not count this failure + * towards failing the stage. This is intended to prevent spurious stage failures in cases + * where many speculative tasks are launched and denied to commit. + */ + override def shouldEventuallyFailJob: Boolean = false } /** @@ -202,8 +210,14 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend * the task crashed the JVM. */ @DeveloperApi -case class ExecutorLostFailure(execId: String) extends TaskFailedReason { - override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" +case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false) + extends TaskFailedReason { + override def toErrorString: String = { + val exitBehavior = if (isNormalExit) "normally" else "abnormally" + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + } + + override def shouldEventuallyFailJob: Boolean = !isNormalExit } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index fb787979c1820..8344f6368ac48 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -239,7 +239,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapSideCombine: Boolean, serializer: Serializer): JavaPairRDD[K, C] = { implicit val ctag: ClassTag[C] = fakeClassTag - fromRDD(rdd.combineByKey( + fromRDD(rdd.combineByKeyWithClassTag( createCombiner, mergeValue, mergeCombiners, diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b4d152b336602..69da180593bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal private[spark] class PythonRDD( - @transient parent: RDD[_], + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -785,7 +785,7 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. 
*/ -private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index bb82f3285f1d9..2a792d81994fd 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. " + s"Candidates are:") selectedMethods.foreach { method => @@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) - val obj = ctor.newInstance(args : _*) + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. " + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } + + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. 
+ def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. 
+ + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 26ad4f1d4697e..3c92bb7a1c73c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R @@ -213,89 +214,97 @@ private[spark] object SerDe { } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.Character" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => writeType(dos, "character") - writeString(dos, value.asInstanceOf[Character].toString) - case "java.lang.String" => + writeString(dos, v.toString) + case v: java.lang.String => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "java.lang.Long" => + writeString(dos, v) + case v: java.lang.Long => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "java.lang.Float" => + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "java.math.BigDecimal" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => writeType(dos, "double") - val javaDecimal = value.asInstanceOf[java.math.BigDecimal] - writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "java.lang.Double" => + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "java.lang.Byte" => + writeDouble(dos, v) + case v: java.lang.Byte => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Byte].toInt) - case "java.lang.Short" => + writeInt(dos, v.toInt) + case v: java.lang.Short => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Short].toInt) - case "java.lang.Integer" => + writeInt(dos, v.toInt) + case v: java.lang.Integer => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "java.lang.Boolean" => + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - 
case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) + writeTime(dos, v) // Handle arrays // Array of primitive types // Special handling for byte array - case "[B" => + case v: Array[Byte] => writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) + writeBytes(dos, v) - case "[C" => + case v: Array[Char] => writeType(dos, "array") - writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) - case "[S" => + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) - case "[I" => + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => + writeIntArr(dos, v) + case v: Array[Long] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[F" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble)) - case "[D" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => + writeDoubleArr(dos, v) + case v: Array[Boolean] => writeType(dos, "array") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + writeBooleanArr(dos, v) // Array of objects, null objects use "void" type - case c if c.startsWith("[") => + case v: Array[Object] => writeType(dos, "list") - val array = value.asInstanceOf[Array[Object]] - writeInt(dos, array.length) - array.foreach(elem => writeObject(dos, elem)) + writeInt(dos, v.length) + v.foreach(elem => writeObject(dos, elem)) case _ => writeType(dos, "jobj") @@ -329,12 +338,11 @@ private[spark] object SerDe { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } - // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { - val len = value.length - out.writeInt(len + 1) // For the \0 - out.writeBytes(value) - out.writeByte(0) + val utf8 = value.getBytes("UTF-8") + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) } def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f7723ef5bde4c..a0b7365df900a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -192,7 +192,9 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. 
*/ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } @@ -204,7 +206,9 @@ class SparkHadoopUtil extends Logging { */ def getTaskAttemptIDFromTaskAttemptContext( context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 86fcf942c2c4e..ad92f5635af35 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -319,8 +319,8 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + case (MESOS, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + @@ -551,7 +551,15 @@ object SparkSubmit { if (isMesosCluster) { assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" - childArgs += (args.primaryResource, args.mainClass) + if (args.isPython) { + // Second argument is main class + childArgs += (args.primaryResource, "") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } + } else { + childArgs += (args.primaryResource, args.mainClass) + } if (args.childArgs != null) { childArgs ++= args.childArgs } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e573ff16c50a3..a5755eac36396 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -73,7 +74,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. 
@@ -179,15 +180,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -224,12 +224,29 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." + UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + fs.delete(path) + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 1fe956320a1b8..957a928bc402b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -392,15 +392,14 @@ private[spark] object RestSubmissionClient { mainClass: String, appArgs: Array[String], conf: SparkConf, - env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + env: Map[String, String] = Map()): SubmitRestProtocolResponse = { val master = conf.getOption("spark.master").getOrElse { throw new IllegalArgumentException("'spark.master' must be set.") } val sparkProperties = conf.getAll.toMap - val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( - appResource, mainClass, appArgs, sparkProperties, environmentVariables) + appResource, mainClass, appArgs, sparkProperties, env) client.createSubmission(submitRequest) } @@ -413,6 +412,16 @@ private[spark] object RestSubmissionClient { val mainClass = args(1) val appArgs = args.slice(2, args.size) val conf = new SparkConf - run(appResource, mainClass, appArgs, conf) + val env = filterSystemEnvironment(sys.env) + run(appResource, mainClass, appArgs, conf, env) + } + + /** + * Filter non-spark environment variables from any environment. 
+ */ + private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { + env.filter { case (k, _) => + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + } } } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index a5ad47293f1c2..e2ffc3b64e5db 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -131,8 +131,8 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat */ @Experimental class PortableDataStream( - @transient isplit: CombineFileSplit, - @transient context: TaskAttemptContext, + isplit: CombineFileSplit, + context: TaskAttemptContext, index: Integer) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efca..9dc36704a676d 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -148,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7c170a742fb64..76968249fb625 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -55,7 +56,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ff8aae9ebe9f0..70a42f9045e6b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { @@ -137,7 +137,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") - result.success() + result.success((): Unit) } override def onFailure(e: Throwable): Unit = { logError(s"Error while uploading block $blockId", e) @@ -149,7 +149,11 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } override def close(): Unit = { - server.close() - clientFactory.close() + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala deleted file mode 100644 index 79cb0640c8672..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -// private[spark] because we need to register them in Kryo -private[spark] case class GetBlock(id: BlockId) -private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) -private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) - -private[nio] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: BlockId = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = BlockId(idBuilder.toString) - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = typ - def getId: BlockId = id - def getData: ByteBuffer = data - def getLevel: StorageLevel = level - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) - buffer.putInt(typ).putInt(id.name.length) - id.name.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[nio] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - 
def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala deleted file mode 100644 index f1c9ea8b64ca3..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark._ -import org.apache.spark.storage.{StorageLevel, TestBlockId} - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) - extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int): BlockMessage = blockMessages(i) - - def iterator: Iterator[BlockMessage] = blockMessages.iterator - - def length: Int = blockMessages.length - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + - (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - 
buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - Message.createBufferMessage(buffers) - } -} - -private[nio] object BlockMessageArray extends Logging { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear() - BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, - StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - logDebug("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - logDebug("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - logDebug("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - logDebug("Converted back to block message array") - // scalastyle:off println - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - // scalastyle:on println - } -} - - diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala deleted file mode 100644 index 9a9e22b0c2366..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.BlockManager - - -private[nio] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size: Int = initialSize - - def currentSize(): Int = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - val security = if (isSecurityNeg) 1 else 0 - if (size == 0 && !gotChunkForSendingOnce) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, - hasError, security, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - val security = if (isSecurityNeg) 1 else 0 - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId(): Boolean = ackId != 0 - - def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - - override def toString: String = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala deleted file mode 100644 index 8d9ebadaf79d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ /dev/null @@ -1,619 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.LinkedList - -import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.control.NonFatal - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} - -private[nio] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, - val securityMgr: SecurityManager) - extends Logging { - - var sparkSaslServer: SparkSaslServer = null - var sparkSaslClient: SparkSaslClient = null - - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, - securityMgr_ : SecurityManager) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), - id_, securityMgr_) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /* channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def isSaslComplete(): Boolean - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - private def disposeSasl() { - if (sparkSaslServer != null) { - sparkSaslServer.dispose() - } - - if (sparkSaslClient != null) { - sparkSaslClient.dispose() - } - } - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. 
- // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key(): SelectionKey = channel.keyFor(selector) - - def getRemoteAddress(): InetSocketAddress = { - channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - } - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. - def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - disposeSasl() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Throwable) => Unit) { - onExceptionCallbacks.add(callback) - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks.asScala.foreach { - callback => - try { - callback(this, e) - } catch { - case NonFatal(e) => { - logWarning("Ignored error in onExceptionCallback", e) - } - } - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.length + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[nio] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslClient != null) sparkSaslClient.isComplete() else false - } - - private class Outbox { - val messages = new LinkedList[Message]() - val defaultChunkSize = 65536 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized { - messages.add(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /* val message = messages(nextMessageToBeUsed) */ - - val message = if 
(securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { - // only allow sending of security messages until sasl is complete - var pos = 0 - var securityMsg: Message = null - while (pos < messages.size() && securityMsg == null) { - if (messages.get(pos).isSecurityNeg) { - securityMsg = messages.remove(pos) - } - pos = pos + 1 - } - // didn't find any security messages and auth isn't completed so return - if (securityMsg == null) return None - securityMsg - } else { - messages.removeFirst() - } - - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.add(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox() - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /* channel.socket.setSendBufferSize(256 * 1024) */ - - override def getRemoteAddress(): InetSocketAddress = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def registerAfterAuth(): Unit = { - outbox.synchronized { - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try { - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => - logError("Error connecting to " + address, e) - callOnExceptionCallbacks(e) - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. 
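// Editor's sketch, not part of the original file: connect()/finishConnect() above are the
// standard java.nio non-blocking connect handshake. A self-contained illustration against
// a throwaway local listener (all names below are illustrative only):
import java.net.InetSocketAddress
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel, SocketChannel}

object NonBlockingConnectSketch {
  def main(args: Array[String]): Unit = {
    // Local listener on an ephemeral port so the connect can actually succeed.
    val server = ServerSocketChannel.open()
    server.bind(new InetSocketAddress("localhost", 0))

    val selector = Selector.open()
    val channel = SocketChannel.open()
    channel.configureBlocking(false)
    channel.register(selector, SelectionKey.OP_CONNECT)
    // In non-blocking mode connect() returns immediately; the selector reports
    // OP_CONNECT once the TCP handshake can be completed.
    channel.connect(server.getLocalAddress)

    if (selector.select(1000) > 0) {
      val key = selector.selectedKeys().iterator().next()
      if (key.isConnectable && channel.finishConnect()) {
        // Only now is the channel usable; switch interest to write (and read).
        key.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
      }
    }
    channel.close()
    server.close()
    selector.close()
  }
}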
- val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallbacks(e) - } - } - true - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as - // normal registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /* key.interestOps(0) */ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), - e) - callOnExceptionCallbacks(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! 
isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection( - channel_ : SocketChannel, - selector_ : Selector, - id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(channel_, selector_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslServer != null) sparkSaslServer.isComplete() else false - } - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = header.securityNeg == 1 - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The receiver's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection, Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... 
- return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /* logDebug("Read " + bytesRead + " bytes for the buffer") */ - - if (currentChunk.buffer.remaining == 0) { - /* println("Filled buffer at " + System.currentTimeMillis) */ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip() - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala deleted file mode 100644 index b3b281ff465f1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString: String = { - connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId - } -} - -private[nio] object ConnectionId { - - def createConnectionIdFromString(connectionIdString: String): ConnectionId = { - val res = connectionIdString.split("_").map(_.trim()) - if (res.size != 3) { - throw new Exception("Error converting ConnectionId string: " + connectionIdString + - " to a ConnectionId Object") - } - new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala deleted file mode 100644 index 9143918790381..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ /dev/null @@ -1,1157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.lang.ref.WeakReference -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future, Promise} -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.{ThreadUtils, Utils} - -import scala.util.Try -import scala.util.control.NonFatal - -private[nio] class ConnectionManager( - port: Int, - conf: SparkConf, - securityManager: SecurityManager, - name: String = "Connection manager") - extends Logging { - - /** - * Used by sendMessageReliably to track messages being sent. 
- * @param message the message that was sent - * @param connectionManagerId the connection manager that sent this message - * @param completionHandler callback that's invoked when the send has completed or failed - */ - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: Try[Message] => Unit) { - - def success(ackMessage: Message) { - if (ackMessage == null) { - failure(new NullPointerException) - } - else { - completionHandler(scala.util.Success(ackMessage)) - } - } - - def failWithoutAck() { - completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) - } - - def failure(e: Throwable) { - completionHandler(scala.util.Failure(e)) - } - } - - private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = - new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) - - private val ackTimeout = - conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", - conf.get("spark.network.timeout", "120s")) - - // Get the thread counts from the Spark Configuration. - // - // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, - // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is - // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" - // parameter is necessary. - private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) - private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4) - private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1) - - private val handleMessageExecutor = new ThreadPoolExecutor( - handlerThreadCount, - handlerThreadCount, - conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-message-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleMessageExecutor is not handled properly", t) - } - } - } - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - ioThreadCount, - ioThreadCount, - conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-read-write-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleReadWriteExecutor is not handled properly", t) - } - } - } - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : - // which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - connectThreadCount, - connectThreadCount, - conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-connect-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleConnectExecutor is not handled properly", t) - } - } - } - - private val serverChannel = ServerSocketChannel.open() - // used to track the SendingConnections 
waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] - with SynchronizedMap[ConnectionId, SendingConnection] - private val connectionsByKey = - new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] - with SynchronizedMap[ConnectionManagerId, SendingConnection] - // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this - // map when messages are sent and are removed when acknowledgement messages are received or when - // acknowledgement timeouts expire - private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor( - ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) - - @volatile - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null - - private val authEnabled = securityManager.isAuthenticationEnabled() - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - private def startService(port: Int): (ServerSocketChannel, Int) = { - serverChannel.socket.bind(new InetSocketAddress(port)) - (serverChannel, serverChannel.socket.getLocalPort) - } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - // used in combination with the ConnectionManagerId to create unique Connection ids - // to be able to track asynchronous messages - private val idCount: AtomicInteger = new AtomicInteger(1) - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - @volatile private var isActive = true - private val selectorThread = new Thread("connection-manager-thread") { - override def run(): Unit = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - // start this thread last, since it invokes run(), which accesses members above - selectorThread.start() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. 
- if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. - if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in - // triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } catch { - case NonFatal(e) => { - logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - // MUST be called within selector loop - else deadlock. 
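// Editor's sketch, not part of the original file: triggerWrite/triggerRead above follow an
// "unregister, hand off to a pool, re-register when done" pattern, with the
// writeRunnableStarted/readRunnableStarted sets guaranteeing that a key is never handled by
// two pool threads at once. A stripped-down version of that dispatch (selector details
// replaced by plain callbacks; all names are illustrative only):
import java.util.concurrent.Executors
import scala.collection.mutable

object HandoffSketch {
  private val pool = Executors.newFixedThreadPool(4)
  private val inFlight = mutable.HashSet[Int]() // plays the role of writeRunnableStarted

  /** Run `work` for `key` on the pool at most once at a time; `reRegister` runs after. */
  def trigger(key: Int)(work: () => Unit, reRegister: () => Unit): Unit = {
    val firstDispatch = inFlight.synchronized { inFlight.add(key) } // false if already queued
    if (firstDispatch) {
      pool.execute(new Runnable {
        override def run(): Unit = {
          try work()
          finally inFlight.synchronized { inFlight -= key; reRegister() }
        }
      })
    }
  }
}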
- private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallbacks(e) - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while (isActive) { - while (!registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue() - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue() - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. - if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + - intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions - // should be dealt with differently. - case e: CancelledKeyException => - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? 
" + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - 0 - - case e: ClosedSelectorException => - logDebug("Failed select() as selector is closed.", e) - return - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + - " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, - // key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, - securityManager) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - connection match { - case sendingConnection: SendingConnection => - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId - - messageStatuses.synchronized { - messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) - .foreach(status => { - logInfo("Notifying " + status) - status.failWithoutAck() - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case receivingConnection: ReceivingConnection => - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - 
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (!sendingConnectionOpt.isDefined) { - logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert(sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values - if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.failWithoutAck() - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case _ => logError("Unsupported type of connection.") - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Throwable) { - logInfo("Handling connection error on connection to " + - connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registrations happen ! - wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - try { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } catch { - case NonFatal(e) => { - logError("Error when handling messages from " + - connection.getRemoteConnectionManagerId(), e) - connection.callOnExceptionCallbacks(e) - } - } - } - } - handleMessageExecutor.execute(runnable) - /* handleMessage(connection, message) */ - } - - private def handleClientAuthentication( - waitingConn: SendingConnection, - securityMsg: SecurityMessage, - connectionId : ConnectionId) { - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } else { - var replyToken : Array[Byte] = null - try { - replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new IOException("Error creating security message") - sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => - logError("Error handling sasl client authentication", e) - waitingConn.close() - throw new IOException("Error evaluating sasl response: ", e) - } - } - } - - private def handleServerAuthentication( - 
connection: Connection, - securityMsg: SecurityMessage, - connectionId: ConnectionId) { - if (!connection.isSaslComplete()) { - logDebug("saslContext not established") - var replyToken : Array[Byte] = null - try { - connection.synchronized { - if (connection.sparkSaslServer == null) { - logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) - } - } - replyToken = connection.sparkSaslServer.response(securityMsg.getToken) - if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId + - " for: " + connectionId) - } else { - logDebug("Server sasl not completed: " + connection.connectionId + - " for: " + connectionId) - } - if (replyToken != null) { - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security Message") - sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } - } catch { - case e: Exception => { - logError("Error in server auth negotiation: " + e) - // It would probably be better to send an error message telling other side auth failed - // but for now just close - connection.close() - } - } - } else { - logDebug("connection already established for this connection id: " + connection.connectionId) - } - } - - - private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { - if (bufferMessage.isSecurityNeg) { - logDebug("This is security neg message") - - // parse as SecurityMessage - val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) - - connectionsAwaitingSasl.get(connectionId) match { - case Some(waitingConn) => { - // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthentication(waitingConn, securityMsg, connectionId) - } - case None => { - // Server - someone sent us something and we haven't authenticated yet - logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthentication(conn, securityMsg, connectionId) - } - } - return true - } else { - if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. - logError("message sent that is not security negotiation message on connection " + - "not authenticated yet, ignoring it!!") - return true - } - } - false - } - - private def handleMessage( - connectionManagerId: ConnectionManagerId, - message: Message, - connection: Connection) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (authEnabled) { - val res = handleAuthentication(connection, bufferMessage) - if (res) { - // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning") - return - } - } - if (bufferMessage.hasAckId()) { - messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status.success(message) - } - case None => { - /** - * We can fall down on this code because of following 2 cases - * - * (1) Invalid ack sent due to buggy code. 
- * - * (2) Late-arriving ack for a SendMessageStatus - * To avoid unwilling late-arriving ack - * caused by long pause like GC, you can set - * larger value than default to spark.core.connection.ack.wait.timeout - */ - logWarning(s"Could not find reference for received ack Message ${message.id}") - } - } - } - } else { - var ackMessage : Option[Message] = None - try { - ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - } catch { - case e: Exception => { - logError(s"Exception was thrown while processing message", e) - ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) - } - } finally { - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { - // see if we need to do sasl before writing - // this should only be the first negotiation as the Client!!! - if (!conn.isSaslComplete()) { - conn.synchronized { - if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) - var firstResponse: Array[Byte] = null - try { - firstResponse = conn.sparkSaslClient.firstToken() - val securityMsg = SecurityMessage.fromResponse(firstResponse, - conn.connectionId.toString()) - val message = securityMsg.toBufferMessage - if (message == null) throw new Exception("Error creating security message") - connectionsAwaitingSasl += ((conn.connectionId, conn)) - sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + - " to: " + connManagerId) - } catch { - case e: Exception => { - logError("Error getting first response from the SaslClient.", e) - conn.close() - throw new Exception("Error getting first response from the SaslClient") - } - } - } - } - } else { - logDebug("Sasl already established ") - } - } - - // allow us to add messages to the inbox for doing sasl negotiating - private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId, securityManager) - logInfo("creating new sending connection for security! " + newConnectionId ) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? - // We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? 
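// Editor's sketch, not part of the original file: the messageStatuses map used by
// sendMessage/handleMessage above is a request/ack correlation table: the sender files an
// entry under the outgoing message id, and the ack (whose ackId echoes that id) completes
// it exactly once. A minimal Promise-based version (names illustrative only):
import scala.collection.concurrent.TrieMap
import scala.concurrent.{Future, Promise}

object AckCorrelationSketch {
  private val pending = TrieMap[Int, Promise[String]]() // messageId -> waiting sender

  def send(messageId: Int): Future[String] = {
    val p = Promise[String]()
    pending.put(messageId, p)
    // ... the message would be written to the wire here ...
    p.future
  }

  def onAck(ackId: Int, payload: String): Unit =
    // Remove first so a duplicate or late ack cannot complete the promise twice.
    pending.remove(ackId).foreach(_.trySuccess(payload))
}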
- message.senderAddress = id.toSocketAddress() - logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") - val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - - // send security message until going connection has been authenticated - connection.send(message) - - wakeupSelector() - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, - connectionManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId, securityManager) - newConnection.onException { - case (conn, e) => { - logError("Exception while sending message.", e) - reportSendingMessageFailure(message.id, e) - } - } - logTrace("creating new sending connection: " + newConnectionId) - registerRequests.enqueue(newConnection) - - newConnection - } - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - - message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) - - if (authEnabled) { - try { - checkSendAuthFirst(connectionManagerId, connection) - } catch { - case NonFatal(e) => { - reportSendingMessageFailure(message.id, e) - } - } - } - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - wakeupSelector() - } - - private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(messageId) - s match { - case Some(msgStatus) => { - messageStatuses -= messageId - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.failure(e) - } - case None => { - logError("no messageStatus for failed message id: " + messageId) - } - } - } - } - - private def wakeupSelector() { - selector.wakeup() - } - - /** - * Send a message and block until an acknowledgment is received or an error occurs. - * @param connectionManagerId the message's destination - * @param message the message being sent - * @return a Future that either returns the acknowledgment message or captures an exception. - */ - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Message] = { - val promise = Promise[Message]() - - // It's important that the TimerTask doesn't capture a reference to `message`, which can cause - // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time - // at which they would originally be scheduled to run. Therefore, extract the message id - // from outside of the TimerTask closure (see SPARK-4393 for more context). 
- val messageId = message.id - // Keep a weak reference to the promise so that the completed promise may be garbage-collected - val promiseReference = new WeakReference(promise) - val timeoutTask: TimerTask = new TimerTask { - override def run(timeout: Timeout): Unit = { - messageStatuses.synchronized { - messageStatuses.remove(messageId).foreach { s => - val e = new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec") - val p = promiseReference.get - if (p != null) { - // Attempt to fail the promise with a Timeout exception - if (!p.tryFailure(e)) { - // If we reach here, then someone else has already signalled success or failure - // on this promise, so log a warning: - logError("Ignore error because promise is completed", e) - } - } else { - // The WeakReference was empty, which should never happen because - // sendMessageReliably's caller should have a strong reference to promise.future; - logError("Promise was garbage collected; this should never happen!", e) - } - } - } - } - } - - val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) - - val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTaskHandle.cancel() - s match { - case scala.util.Failure(e) => - // Indicates a failure where we either never sent or never got ACK'd - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - case scala.util.Success(ackMessage) => - if (ackMessage.hasError) { - val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head - val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) - errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, UTF_8) - val e = new IOException( - s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - } else { - if (!promise.trySuccess(ackMessage)) { - logWarning("Drop ackMessage because promise is completed") - } - } - } - }) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - - sendMessage(connectionManagerId, message) - promise.future - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - isActive = false - ackTimeoutMonitor.stop() - selector.close() - selectorThread.interrupt() - selectorThread.join() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - import scala.concurrent.ExecutionContext.Implicits.global - - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // scalastyle:off println - println("Received [" + msg + "] from [" + id + "]") - // scalastyle:on println - None - }) - - /* testSequentialSending(manager) */ - /* System.gc() */ - - /* testParallelSending(manager) */ - /* System.gc() */ - - /* testParallelDecreasingSending(manager) */ - /* System.gc() */ - - testContinuousSending(manager) - System.gc() - } - - // scalastyle:off println - def 
testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms " + - "(" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count) { i => - val bufferLen = size * (i + 1) - val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte) - ByteBuffer.allocate(bufferLen).put(bufferContent) - } - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /* println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val 
mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } - // scalastyle:on println -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala deleted file mode 100644 index 85d2fe2bf9c20..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 - -import org.apache.spark.util.Utils - -private[nio] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString: String = { - this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" - } -} - - -private[nio] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId(): Int = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - /** - * Create a "negative acknowledgment" to notify a sender that an error occurred - * while processing its message. 
The exception's stacktrace will be formatted - * as a string, serialized into a byte array, and sent as the message payload. - */ - def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { - val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) - val errorMessage = createBufferMessage(serializedExceptionString, ackId) - errorMessage.hasError = true - errorMessage - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala deleted file mode 100644 index 7b3da4bb9d5ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.{InetAddress, InetSocketAddress} -import java.nio.ByteBuffer - -private[nio] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString: String = { - "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - } - -} - - -private[nio] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala deleted file mode 100644 index b2aec160635c7..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} - -import scala.concurrent.Future - - -/** - * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom - * implementation using Java NIO. - */ -final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) - extends BlockTransferService with Logging { - - private var cm: ConnectionManager = _ - - private var blockDataManager: BlockDataManager = _ - - /** - * Port number the service is listening on, available only after [[init]] is invoked. - */ - override def port: Int = { - checkInit() - cm.id.port - } - - /** - * Host name the service is listening on, available only after [[init]] is invoked. - */ - override def hostName: String = { - checkInit() - cm.id.host - } - - /** - * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. 
- */ - override def init(blockDataManager: BlockDataManager): Unit = { - this.blockDataManager = blockDataManager - cm = new ConnectionManager( - conf.getInt("spark.blockManager.port", 0), - conf, - securityManager, - "Connection manager for block manager") - cm.onReceiveMessage(onBlockMessageReceive) - } - - /** - * Tear down the transfer service. - */ - override def close(): Unit = { - if (cm != null) { - cm.stop() - } - } - - override def fetchBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener): Unit = { - checkInit() - - val cmId = new ConnectionManagerId(host, port) - val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => - BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) - }) - - val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - - // Register the listener on success/failure future callback. - future.onSuccess { case message => - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - - // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. - if (blockMessageArray.isEmpty) { - blockIds.foreach { id => - listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) - } - } else { - for (blockMessage: BlockMessage <- blockMessageArray) { - val msgType = blockMessage.getType - if (msgType != BlockMessage.TYPE_GOT_BLOCK) { - if (blockMessage.getId != null) { - listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message $msgType received from $cmId")) - } - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioManagedBuffer(blockMessage.getData)) - } - } - } - }(cm.futureExecContext) - - future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) - } - }(cm.futureExecContext) - } - - /** - * Upload a single block to a remote node, available only after [[init]] is invoked. - * - * This call blocks until the upload completes, or throws an exception upon failures. 
- */ - override def uploadBlock( - hostname: String, - port: Int, - execId: String, - blockId: BlockId, - blockData: ManagedBuffer, - level: StorageLevel) - : Future[Unit] = { - checkInit() - val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) - val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) - val remoteCmId = new ConnectionManagerId(hostName, port) - val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) - reply.map(x => ())(cm.futureExecContext) - } - - private def checkInit(): Unit = if (cm == null) { - throw new IllegalStateException(getClass.getName + " has not been initialized") - } - - private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => - logError("Exception handling buffer message", e) - Some(Message.createErrorMessage(e, msg.id)) - } - - case otherMessage: Any => - val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" - logError(errorMsg) - Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) - } - } - - private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => - val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + msg + "]") - putBlock(msg.id, msg.data, msg.level) - None - - case BlockMessage.TYPE_GET_BLOCK => - val msg = new GetBlock(blockMessage.getId) - logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) - - case _ => None - } - } - - private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) - logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(blockId: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId) - logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer.nioByteBuffer() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala deleted file mode 100644 index 232c552f9865d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -import org.apache.spark._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. - * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[nio] class SecurityMessage extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. 
- */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. - * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[nio] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. 
- * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 1f755db485812..6fec00dcd0d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -28,7 +28,7 @@ private[spark] class BinaryFileRDD[T]( inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -36,10 +36,10 @@ private[spark] class BinaryFileRDD[T]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 922030263756b..fc1710fbad0a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -28,7 +28,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @@ -64,7 +64,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds */ private[spark] def removeBlocks() { blockIds.foreach { blockId => - sc.env.blockManager.master.removeBlock(blockId) + sparkContext.env.blockManager.master.removeBlock(blockId) } _isValid = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index c1d6971787572..18e8cddbc40db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -27,8 +27,8 @@ import org.apache.spark.util.Utils private[spark] class CartesianPartition( idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], + @transient private val rdd1: RDD[_], + @transient private val rdd2: RDD[_], s1Index: Int, s2Index: Int ) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 72fe215dae73e..b0364623af4cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -29,7 +29,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends 
Partition /** * An RDD that recovers checkpointed data from storage. */ -private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) +private[spark] abstract class CheckpointRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { // CheckpointRDD should not be checkpointed again diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9c617fc719cb5..7bad749d58327 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -22,6 +22,7 @@ import scala.language.existentials import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -74,7 +75,9 @@ private[spark] class CoGroupPartition( * @param part partitioner used to partition the shuffle output */ @DeveloperApi -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) +class CoGroupedRDD[K: ClassTag]( + @transient var rdds: Seq[RDD[_ <: Product2[K, _]]], + part: Partitioner) extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs). diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1f8719eead02..8f2655d63b797 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -109,7 +109,7 @@ class HadoopRDD[K, V]( extends RDD[(K, V)](sc, Nil) with Logging { if (initLocalJobConfFuncOpt.isDefined) { - sc.clean(initLocalJobConfFuncOpt.get) + sparkContext.clean(initLocalJobConfFuncOpt.get) } def this( @@ -137,7 +137,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. 
protected def getJobConf(): JobConf = { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala index daa5779d688cc..bfe19195fcd37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.storage.RDDBlockId * @param numPartitions the number of partitions in the checkpointed RDD */ private[spark] class LocalCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, rddId: Int, numPartitions: Int) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index d6fad896845f6..c115e0ff74d3c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * is written to the local, ephemeral block storage that lives in each executor. This is useful * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). */ -private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala index 1f2213d0c4346..417ff5278db2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -41,7 +41,7 @@ private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M // In certain join operations, prepare can be called on the same partition multiple times. // In this case, we need to ensure that each call to compute gets a separate prepare argument. - private[this] var preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] + private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] /** * Prepare a partition for a single call to compute. 
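Several of the RDD constructor changes in this patch swap a bare `@transient` parameter (for example `@transient sc: SparkContext` or `@transient rdd: RDD[T]`) for either a plain parameter or an explicit `@transient private val`. The sketch below, with hypothetical class names that are not part of the patch, illustrates why the explicit form matters: only a real field can carry the transient flag through Java serialization, while a plain constructor parameter becomes a field only when the class body happens to use it.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical classes used only to contrast the two parameter forms; they are
// not taken from the patch.
class Heavy extends Serializable { val payload = new Array[Byte](1 << 20) }

// A bare constructor parameter: it becomes a (non-transient) field only because
// the class body uses it, so the payload travels with every serialized copy.
class KeepsReference(heavy: Heavy) extends Serializable {
  def size: Int = heavy.payload.length
}

// An explicit `@transient private val`: the backing field definitely exists and
// is marked transient, so Java serialization skips it.
class DropsReference(@transient private val heavy: Heavy) extends Serializable {
  def isOriginal: Boolean = heavy != null
}

object TransientSketch {
  private def serializedSize(o: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(o)       // walks all non-transient fields reachable from o
    out.close()
    bytes.size()
  }

  def main(args: Array[String]): Unit = {
    val heavy = new Heavy
    println(s"plain parameter:        ${serializedSize(new KeepsReference(heavy))} bytes")
    println(s"@transient private val: ${serializedSize(new DropsReference(heavy))} bytes")
  }
}

If compiled as written, the first line should report roughly a megabyte and the second only a handful of bytes; the exact numbers are not significant, only the contrast.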
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6a9c004d65cff..174979aaeb231 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -40,7 +40,7 @@ import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -68,14 +68,14 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient private val _conf: Configuration) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) - // private val serializableConf = new SerializableWritable(conf) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) + // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -88,10 +88,10 @@ class NewHadoopRDD[K, V]( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -262,7 +262,7 @@ private[spark] class WholeTextFileRDD( inputFormatClass: Class[_ <: WholeTextFileInputFormat], keyClass: Class[String], valueClass: Class[String], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { @@ -270,10 +270,10 @@ private[spark] class WholeTextFileRDD( val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(getConf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(getConf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 4e5f2e8a5d467..199d79b811d65 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -57,7 +57,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) with SparkHadoopMapReduceUtil with Serializable { + /** + * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. 
Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type @@ -70,12 +72,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). */ - def combineByKey[C](createCombiner: V => C, + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = self.withScope { + serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -103,13 +107,50 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] */ - def combineByKey[C](createCombiner: V => C, + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializer: Serializer = null): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + partitioner, mapSideCombine, serializer)(null) + } + + /** + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + * This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, numPartitions: Int): RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, numPartitions)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. 
+ */ + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + new HashPartitioner(numPartitions)) } /** @@ -133,7 +174,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) - combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) + combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v), + cleanedSeqOp, combOp, partitioner) } /** @@ -182,7 +224,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) - combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) + combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), + cleanedFunc, cleanedFunc, partitioner) } /** @@ -268,7 +311,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { - combineByKey[V]((v: V) => v, func, func, partitioner) + combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** @@ -392,7 +435,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) h1 } - combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.cardinality()) + combineByKeyWithClassTag(createHLL, mergeValueHLL, mergeHLL, partitioner) + .mapValues(_.cardinality()) } /** @@ -466,7 +510,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createCombiner = (v: V) => CompactBuffer(v) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -565,12 +609,30 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. This method is here for backward compatibility. It + * does not provide combiner classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. 
*/ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } /** @@ -934,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -1002,7 +1065,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e2394e28f8d26..582fa93afe34e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -83,8 +83,8 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } private[spark] class ParallelCollectionRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient data: Seq[T], + sc: SparkContext, + @transient private val data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index a00f4c1cdff91..d6a37e8cc5dac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -32,7 +32,7 @@ private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Par * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. 
*/ -private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -55,8 +55,8 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF */ @DeveloperApi class PartitionPruningRDD[T: ClassTag]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) + prev: RDD[T], + partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index a637d6f15b7e5..3b1acacf409b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -47,8 +47,8 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient preservesPartitioning: Boolean, - @transient seed: Long = Utils.random.nextLong) + preservesPartitioning: Boolean, + @transient private val seed: Long = Utils.random.nextLong) extends RDD[U](prev) { @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 081c721f23687..7dd2bc5d7cd72 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1666,7 +1666,7 @@ abstract class RDD[T: ClassTag]( import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" - val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + val storageInfo = rdd.context.getRDDStorageInfo(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), bytesToString(info.externalBlockStoreSize), bytesToString(info.diskSize))) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 0e43520870c0a..429514b4f6bee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -36,7 +36,7 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. 
*/ -private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 44667281c1063..540cbd688b63b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOr import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.google.common.base.Objects import org.apache.spark.{Logging, SparkContext} @@ -67,6 +68,8 @@ private[spark] class RDDOperationScope( } } + override def hashCode(): Int = Objects.hashCode(id, name, parent) + override def toString: String = toJson } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 35d8b0bfd18c5..1c3b5da19ceba 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} * An RDD that reads from checkpoint files previously written to reliable storage. */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val checkpointPath: String) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 1df8eef5ff2b9..e9f6060301ba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.SerializableConfiguration * An implementation of checkpointing that writes the RDD data to reliable storage. * This allows drivers to be restarted on failure with previously computed state. 
*/ -private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { // The directory to which the associated RDD has been checkpointed to diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 2dc47f95937cb..cb15d912bbfb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -37,7 +39,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { */ // TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C]( +class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index fa3fecc80cb63..0228c54e0511c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Ut private[spark] class SqlNewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends SparkPartition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -61,9 +61,9 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. 
*/ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - @transient sc : SparkContext, + sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], + @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) @@ -86,7 +86,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a4fa301b06e3..25ec685eff5ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -63,15 +63,17 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => + def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) + : Dependency[_] = { if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part, serializer) } } + Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) } override def getPartitions: Array[Partition] = { @@ -105,7 +107,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = { dependencies(depNum) match { case oneToOneDependency: OneToOneDependency[_] => val dependencyPartition = partition.narrowDeps(depNum).get.split diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 3986645350a82..66cf4369da2ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class UnionPartition[T: ClassTag]( idx: Int, - @transient rdd: RDD[T], + @transient private val rdd: RDD[T], val parentRddIndex: Int, - @transient parentRddPartitionIndex: Int) + @transient private val parentRddPartitionIndex: Int) extends Partition { var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index b3c64394abc76..70bf04de6400d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]], + @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e277ae28d588f..32931d59acb18 100644 --- 
a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -37,7 +37,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) * @tparam T parent RDD item type */ private[spark] -class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { +class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) { /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 7409ac8859991..f25710bb5bd6e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) +private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging { private[this] val maxRetries = RpcUtils.numRetries(conf) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index fc17542abf81d..ad67e1c5ad4d5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -87,9 +87,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use lazy because the Actor needs to use `endpointRef`. + // Use a deferred function because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`.
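// Illustrative sketch (editor's addition, not part of this patch). The change below
// replaces a lazily created `actorRef` with an explicit `() => ActorRef` thunk that
// is forced later through a `lazy val` (see AkkaRpcEndpointRef further down). The
// underlying pattern, with hypothetical names, is simply:
class DeferredHolder[T](make: () => T) {
  lazy val value: T = make() // evaluated at most once, on first access
}
// val holder = new DeferredHolder(() => buildExpensiveThing())
// holder.value // construction happens here, not when the holder is created
// (end of sketch)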
- lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) @@ -272,13 +272,20 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } private[akka] class AkkaRpcEndpointRef( - @transient defaultAddress: RpcAddress, - @transient _actorRef: => ActorRef, - @transient conf: SparkConf, - @transient initInConstructor: Boolean = true) + @transient private val defaultAddress: RpcAddress, + @transient private val _actorRef: () => ActorRef, + conf: SparkConf, + initInConstructor: Boolean) extends RpcEndpointRef(conf) with Logging { - lazy val actorRef = _actorRef + def this( + defaultAddress: RpcAddress, + _actorRef: ActorRef, + conf: SparkConf) = { + this(defaultAddress, () => _actorRef, conf, true) + } + + lazy val actorRef = _actorRef() override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 11d123eec43ca..b6bff64ee368e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -34,9 +34,15 @@ class AccumulableInfo private[spark] ( override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value + this.update == acc.update && this.value == acc.value && + this.internal == acc.internal case _ => false } + + override def hashCode(): Int = { + val state = Seq(id, name, update, value, internal) + state.map(_.hashCode).reduceLeft(31 * _ + _) + } } object AccumulableInfo { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index daf9b0f95273e..09e963f5cdf60 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -250,11 +250,12 @@ class DAGScheduler( case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, firstJobId) + getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + } // Then register current shuffleDep val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - stage } } @@ -365,16 +366,6 @@ class DAGScheduler( parents.toList } - /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { - val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (parentsWithNoMapStage.nonEmpty) { - val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) - shuffleToMapStage(currentShufDep.shuffleId) = stage - } - } - /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] @@ -391,11 +382,9 @@ class DAGScheduler( if 
(!shuffleToMapStage.contains(shufDep.shuffleId)) { parents.push(shufDep) } - - waitingForVisit.push(shufDep.rdd) case _ => - waitingForVisit.push(dep.rdd) } + waitingForVisit.push(dep.rdd) } } } @@ -1052,10 +1041,11 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head), + shuffleStage.outputLocs.map(_.headOption.orNull), changeEpoch = true) clearCacheLocs() + if (shuffleStage.outputLocs.contains(Nil)) { // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this @@ -1064,27 +1054,9 @@ class DAGScheduler( shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) .map(_._2).mkString(", ")) submitStage(shuffleStage) - } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (shuffleStage <- waitingStages) { - logInfo("Missing parents for " + shuffleStage + ": " + - getMissingParentStages(shuffleStage)) - } - for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) - { - newlyRunnable += shuffleStage - } - waitingStages --= newlyRunnable - runningStages ++= newlyRunnable - for { - shuffleStage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(shuffleStage) - } { - logInfo("Submitting " + shuffleStage + " (" + - shuffleStage.rdd + "), which is now runnable") - submitMissingTasks(shuffleStage, jobId) - } } + + // Note: newly runnable stages will be submitted below when we submit waiting stages } } @@ -1101,7 +1073,6 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") } else { - // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. @@ -1117,6 +1088,11 @@ class DAGScheduler( if (disallowStageRetryForTest) { abortStage(failedStage, "Fetch failure will not retry stage due to testing config", None) + } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { + abortStage(failedStage, s"$failedStage (${failedStage.name}) " + + s"has failed the maximum allowable number of " + + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + + s"Most recent failure reason: ${failureMessage}", None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. @@ -1182,7 +1158,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head) + val locs = stage.outputLocs.map(_.headOption.orNull) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { @@ -1240,10 +1216,17 @@ class DAGScheduler( if (errorMessage.isEmpty) { logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. 
+ // We only limit consecutive failures of stage attempts, so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() } else { stage.latestInfo.stageFailed(errorMessage.get) logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) } + outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 2bc43a9186449..0a98c69b89ea5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -23,16 +23,20 @@ import org.apache.spark.executor.ExecutorExitCode * Represents an explanation for a executor or whole slave failing or exiting. */ private[spark] -class ExecutorLossReason(val message: String) { +class ExecutorLossReason(val message: String) extends Serializable { override def toString: String = message } private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String) + extends ExecutorLossReason(reason) + +private[spark] object ExecutorExited { + def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = { + ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode)) + } } private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} + extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 5821afea98982..551e39a81b695 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -83,8 +83,8 @@ private[spark] class Pool( null } - override def executorLost(executorId: String, host: String) { - schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) + override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) { + schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } override def checkSpeculatableTasks(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c4dc080e2b22b..fb693721a9cb6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -44,7 +44,7 @@ private[spark] class ResultTask[T, U]( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient locs: Seq[TaskLocation], + locs: Seq[TaskLocation], val outputId: Int, internalAccumulators: Seq[Accumulator[Long]]) extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index a87ef030e69c2..ab00bc8f0bf4e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,7 +42,7 @@ private[spark] trait Schedulable { def
addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit + def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 1cf06856ffbc2..c086535782c23 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.CallSite * be updated for each attempt. * */ -private[spark] abstract class Stage( +private[scheduler] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, @@ -92,6 +92,29 @@ private[spark] abstract class Stage( */ private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + /** + * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these + * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * We keep track of each attempt ID that has failed to avoid recording duplicate failures if + * multiple tasks from the same stage attempt fail (SPARK-5945). + */ + private val fetchFailedAttemptIds = new HashSet[Int] + + private[scheduler] def clearFailures() : Unit = { + fetchFailedAttemptIds.clear() + } + + /** + * Check whether we should abort the failedStage due to multiple consecutive fetch failures. + * + * This method updates the running set of failed stage attempts and returns + * true if the number of failures exceeds the allowable number of failures. + */ + private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { + fetchFailedAttemptIds.add(stageAttemptId) + fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES + } + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -110,3 +133,8 @@ private[spark] abstract class Stage( case _ => false } } + +private[scheduler] object Stage { + // The number of consecutive failures allowed before a stage is aborted + val MAX_CONSECUTIVE_FETCH_FAILURES = 4 +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1705e7f962de2..1c7bfe89c02ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -332,7 +332,8 @@ private[spark] class TaskSchedulerImpl( // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) + removeExecutor(execId, + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) } } @@ -464,7 +465,7 @@ private[spark] class TaskSchedulerImpl( if (activeExecutorIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) + removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { // We may get multiple executorLost() calls with different loss reasons. 
For example, one @@ -482,7 +483,7 @@ private[spark] class TaskSchedulerImpl( } /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { + private def removeExecutor(executorId: String, reason: ExecutorLossReason) { activeExecutorIds -= executorId val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) @@ -497,7 +498,7 @@ private[spark] class TaskSchedulerImpl( } } executorIdToHost -= executorId - rootPool.executorLost(executorId, host) + rootPool.executorLost(executorId, host, reason) } def executorAdded(execId: String, host: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 818b95d67f6be..62af9031b9f8b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -709,6 +709,11 @@ private[spark] class TaskSetManager( } ef.exception + case e: ExecutorLostFailure if e.isNormalExit => + logInfo(s"Task $tid failed because while it was being computed, its executor" + + s" exited normally. Not marking the task as failed.") + None + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) None @@ -722,10 +727,9 @@ private[spark] class TaskSetManager( put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { - // If a task failed because its attempt to commit was denied, do not count this failure - // towards failing the stage. This is intended to prevent spurious stage failures in cases - // where many speculative tasks are launched and denied to commit. + if (!isZombie && state != TaskState.KILLED + && reason.isInstanceOf[TaskFailedReason] + && reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -778,7 +782,7 @@ private[spark] class TaskSetManager( } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { + override def executorLost(execId: String, host: String, reason: ExecutorLossReason) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) // Re-enqueue pending tasks for this host based on the status of the cluster. 
Note @@ -809,9 +813,12 @@ private[spark] class TaskSetManager( } } } - // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) + val isNormalExit: Boolean = reason match { + case exited: ExecutorExited => exited.isNormalExit + case _ => false + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 06f5438433b6e..d94743677783f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -70,7 +71,8 @@ private[spark] object CoarseGrainedClusterMessages { case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) + extends CoarseGrainedClusterMessage case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage @@ -92,6 +94,10 @@ private[spark] object CoarseGrainedClusterMessages { hostToLocalTaskCount: Map[String, Int]) extends CoarseGrainedClusterMessage + // Check if an executor was force-killed but for a normal reason. + // This could be the case if the executor is preempted, for instance. 
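// Illustrative sketch (editor's addition, not part of this patch; the helper name
// and `preempted` flag are hypothetical). With the reworked
// ExecutorExited(exitCode, isNormalExit, reason), a backend that knows an executor
// was killed for a benign reason can report a normal exit, so the TaskSetManager
// change above does not count the lost tasks towards maxTaskFailures:
def lossReasonFor(exitCode: Int, preempted: Boolean): ExecutorLossReason =
  if (preempted) {
    ExecutorExited(exitCode, isNormalExit = true, "Executor preempted by the cluster manager")
  } else {
    ExecutorExited(exitCode, isNormalExit = false) // message derived from the exit code
  }
// (end of sketch)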
+ case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5730a87f960a0..18771f79b44bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** @@ -82,7 +83,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[RpcAddress, String] + protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -128,6 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { @@ -185,8 +187,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, - "remote Rpc client disassociated")) + addressToExecutorId + .get(remoteAddress) + .foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated"))) } // Make fake resource offers on just one executor @@ -227,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String): Unit = { + def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -239,9 +242,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, SlaveLost(reason)) + scheduler.executorLost(executorId, reason) listenerBus.post( - SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -263,8 +266,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint( - CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) + driverEndpoint = 
rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + } + + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new DriverEndpoint(rpcEnv, properties) } def stopExecutors() { @@ -304,7 +310,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: ExecutorLossReason) { try { driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bbe51b4a09a22..27491ecf8b97d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,7 +23,7 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( @@ -135,11 +135,11 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) + case Some(code) => ExecutorExited(code, isNormalExit = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) - removeExecutor(fullId.split("/")(1), reason.toString) + removeExecutor(fullId.split("/")(1), reason) } override def sufficientResourcesRegistered(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 044f6288fabdd..6a4b536dee191 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,12 +17,13 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Future, ExecutionContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler._ import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{ThreadUtils, RpcUtils} @@ -43,8 +44,10 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) @@ -53,7 +56,7 @@ private[spark] abstract class 
YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean]( + yarnSchedulerEndpointRef.askWithRetry[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } @@ -61,7 +64,7 @@ private[spark] abstract class YarnSchedulerBackend( * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -90,6 +93,41 @@ private[spark] abstract class YarnSchedulerBackend( } } + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. + * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. + * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. In particular, the executor may have exited due to the executor + * having been preempted. If the executor "exited normally" according to the application + * master then we pass that information down to the TaskSetManager to inform the + * TaskSetManager that tasks on that lost executor should not count towards a job failure. + * + * TODO there's a race condition where while we are querying the ApplicationMaster for + * the executor loss reason, there is the potential that tasks will be scheduled on + * the executor that failed. We should fix this by having this onDisconnected event + * also "blacklist" executors so that tasks are not assigned to them. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + /** * An [[RpcEndpoint]] that communicates with the ApplicationMaster. */ @@ -101,6 +139,33 @@ private[spark] abstract class YarnSchedulerBackend( ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => { + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + } + future onFailure { + case NonFatal(e) => { + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. 
Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + } + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + } + } + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") @@ -113,6 +178,7 @@ private[spark] abstract class YarnSchedulerBackend( removeExecutor(executorId, reason) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => amEndpoint match { @@ -143,7 +209,6 @@ private[spark] abstract class YarnSchedulerBackend( logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) } - } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 452c32d5411cd..65cb5016cfcc9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -32,7 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -222,6 +222,10 @@ private[spark] class CoarseMesosSchedulerBackend( markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -364,7 +368,7 @@ private[spark] class CoarseMesosSchedulerBackend( if (slaveIdToTaskId.containsKey(slaveId)) { val taskId: Int = slaveIdToTaskId.get(slaveId) taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) + removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) } // TODO: This assumes one Spark executor per Mesos slave, // which may no longer be true after SPARK-5095 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 07da9242b9922..a6d9374eb9e8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,7 +29,6 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} - import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem @@ -375,21 +374,20 @@ private[spark] class MesosClusterScheduler( val 
executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") envBuilder.addVariables( Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) - val cmdOptions = generateCmdOption(desc).mkString(" ") val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image") val executorUri = desc.schedulerProperties.get("spark.executor.uri") .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) - val appArguments = desc.command.arguments.mkString(" ") - val (executable, jar) = if (dockerDefined) { + // Gets the path to run spark-submit, and the path to the Mesos sandbox. + val (executable, sandboxPath) = if (dockerDefined) { // Application jar is automatically downloaded in the mounted sandbox by Mesos, // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable. - ("./bin/spark-submit", s"$$MESOS_SANDBOX/${desc.jarUrl.split("/").last}") + ("./bin/spark-submit", "$MESOS_SANDBOX") } else if (executorUri.isDefined) { builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) val folderBasename = executorUri.get.split('/').last.split('.').head val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" - val cmdJar = s"../${desc.jarUrl.split("/").last}" - (cmdExecutable, cmdJar) + // Sandbox path points to the parent folder as we chdir into the folderBasename. + (cmdExecutable, "..") } else { val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") .orElse(conf.getOption("spark.home")) @@ -398,30 +396,50 @@ private[spark] class MesosClusterScheduler( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath - val cmdJar = desc.jarUrl.split("/").last - (cmdExecutable, cmdJar) + // Sandbox points to the current directory by default with Mesos. 
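// Worked example (editor's addition, not part of this patch; file names are made up,
// Unix-style separators assumed). The (executable, sandboxPath) pairs feed the
// primaryResource construction a few lines below, new File(sandboxPath, jarName):
locally {
  val jarName = "job.jar"
  assert(new java.io.File("$MESOS_SANDBOX", jarName).toString == "$MESOS_SANDBOX/job.jar") // Docker case
  assert(new java.io.File("..", jarName).toString == "../job.jar")                         // executor URI case
  assert(new java.io.File(".", jarName).toString == "./job.jar")                           // default case below
}
// (end of sketch)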
+ (cmdExecutable, ".") } - builder.setValue(s"$executable $cmdOptions $jar $appArguments") + val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") + val appArguments = desc.command.arguments.mkString(" ") + builder.setValue(s"$executable $cmdOptions $primaryResource $appArguments") builder.setEnvironment(envBuilder.build()) conf.getOption("spark.mesos.uris").map { uris => setupUris(uris, builder) } + desc.schedulerProperties.get("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + setupUris(pyFiles, builder) + } builder.build() } - private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + private def generateCmdOption(desc: MesosDriverDescription, sandboxPath: String): Seq[String] = { var options = Seq( "--name", desc.schedulerProperties("spark.app.name"), - "--class", desc.command.mainClass, "--master", s"mesos://${conf.get("spark.master")}", "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + + // Assume empty main class means we're running python + if (!desc.command.mainClass.equals("")) { + options ++= Seq("--class", desc.command.mainClass) + } + desc.schedulerProperties.get("spark.executor.memory").map { v => options ++= Seq("--executor-memory", v) } desc.schedulerProperties.get("spark.cores.max").map { v => options ++= Seq("--total-executor-cores", v) } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + val formattedFiles = pyFiles.split(",") + .map { path => new File(sandboxPath, path.split("/").last).toString() } + .mkString(",") + options ++= Seq("--py-files", formattedFiles) + } options } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 2e424054be785..8edf7007a5daf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils - /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -127,7 +126,7 @@ private[spark] class MesosSchedulerBackend( } val builder = MesosExecutorInfo.newBuilder() val (resourcesAfterCpu, usedCpuResources) = - partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) @@ -390,7 +389,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) + recordSlaveLost(d, slaveId, ExecutorExited(status, isNormalExit = false)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index b977711e7d5ad..c5195c1143a8f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -35,7 +35,6 @@ import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, Roaring import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -362,9 +361,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[PutBlock], - classOf[GotBlock], - classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], classOf[RoaringBitmap], diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index df7bbd64247dd..75f22f642b9d1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -159,7 +159,7 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage mapId: Int, context: TaskContext): ShuffleWriter[K, V] = { handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) val env = SparkEnv.get new UnsafeShuffleWriter( diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala similarity index 60% rename from core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala rename to core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala index 1cd13d887c6f6..f6e46ae9a481a 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala @@ -15,23 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.network.nio +package org.apache.spark.storage -import java.net.InetSocketAddress +import org.apache.spark.SparkException -import org.apache.spark.util.Utils - -private[nio] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) -} - - -private[nio] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) - } -} +private[spark] +case class BlockFetchException(messages: String, throwable: Throwable) + extends SparkException(messages, throwable) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fefaef0ab82c8..d31aa68eb6954 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,6 +23,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import scala.util.Random import sun.nio.ch.DirectBuffer @@ -600,10 +601,26 @@ private[spark] class BlockManager( private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) + var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + val data = try { + blockTransferService.fetchBlockSync( + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + } catch { + case NonFatal(e) => + numFetchFailures += 1 + if (numFetchFailures == locations.size) { + // An exception is thrown while fetching this block from all locations + throw new BlockFetchException(s"Failed to fetch block from" + + s" ${locations.size} locations. Most recent failure cause:", e) + } else { + // This location failed, so we retry fetch from a different one by returning null here + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $numFetchFailures)", e) + null + } + } if (data != null) { if (asBlockResult) { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3f8d26e1d4cab..f7e84a2c2e14c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -164,7 +164,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. 
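// Simplified sketch (editor's addition, not part of this patch; names are
// hypothetical). The doGetRemote change above follows a "try every replica, fail
// only when all of them have failed" shape, which in isolation looks like:
import scala.util.control.NonFatal

def fetchFromAny[A](locations: Seq[String], fetch: String => A): A = {
  var failures = 0
  for (loc <- locations) {
    try {
      return fetch(loc) // first location that succeeds wins
    } catch {
      case NonFatal(e) =>
        failures += 1
        if (failures == locations.size) {
          // every location failed: surface the most recent cause
          throw new RuntimeException(s"All ${locations.size} locations failed", e)
        }
        // otherwise fall through and try the next location
    }
  }
  throw new IllegalArgumentException("No locations to fetch from")
}
// (end of sketch)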
+ if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a759ceb96ec1e..0d0448feb5b06 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -260,10 +260,7 @@ final class ShuffleBlockFetcherIterator( fetchRequests ++= Utils.randomize(remoteRequests) // Send out initial requests for blocks, up to our maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() val numFetches = remoteRequests.size - fetchRequests.size logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) @@ -296,10 +293,7 @@ final class ShuffleBlockFetcherIterator( case _ => } // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() result match { case FailureFetchResult(blockId, address, e) => @@ -315,6 +309,14 @@ final class ShuffleBlockFetcherIterator( } } + private def fetchUpToMaxBytes(): Unit = { + // Send fetch requests up to maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + } + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 150d82b3930ef..1b49dca9dc78b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -94,7 +94,7 @@ private[spark] object ClosureCleaner extends Logging { if (cls.isPrimitive) { cls match { case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Character.TYPE => new java.lang.Character('\u0000') case java.lang.Void.TYPE => // This should not happen because `Foo(void x) {}` does not compile. 
throw new IllegalStateException("Unexpected void parameter in constructor") diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index cbc94fd6d54d9..24f78744ad74c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,8 +362,9 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) - case ExecutorLostFailure(executorId) => - ("Executor ID" -> executorId) + case ExecutorLostFailure(executorId, isNormalExit) => + ("Executor ID" -> executorId) ~ + ("Normal Exit" -> isNormalExit) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -794,8 +795,10 @@ private[spark] object JsonProtocol { case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => + val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). + map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown")) + ExecutorLostFailure(executorId.getOrElse("Unknown"), isNormalExit.getOrElse(false)) case `unknownReason` => UnknownReason } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 61ff9b89ec1c1..db4a8b304ec3e 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -217,7 +217,9 @@ private [util] class SparkShutdownHookManager { } Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + val fsPriority = classOf[FileSystem] + .getField("SHUTDOWN_HOOK_PRIORITY") + .get(null) // static field, the value is not used .asInstanceOf[Int] val shm = shmClass.getMethod("get").invoke(null) shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 9c15b1188d91c..7ab67fc3a2de9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable { */ def capacity: Int = numWords * 64 + /** + * Clear all set bits. + */ + def clear(): Unit = { + var i = 0 + while (i < numWords) { + words(i) = 0L + i += 1 + } + } + /** * Set all the bits up to a given index */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 19287edbaf166..31230d5978b2a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -188,6 +188,12 @@ private[spark] class ExternalSorter[K, V, C]( private val spills = new ArrayBuffer[SpilledFile] + /** + * Number of files this sorter has spilled so far. + * Exposed for testing. 
+ */ + private[spark] def numSpills: Int = spills.size + override def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -291,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 7138b4b8e4533..1e8476c4a047e 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -79,32 +79,30 @@ private[spark] class RollingFileAppender( val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() val rolloverFile = new File( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile - try { - logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") - if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) - logInfo(s"Rolled over $activeFile to $rolloverFile") - } else { - // In case the rollover file name clashes, make a unique file name. - // The resultant file names are long and ugly, so this is used only - // if there is a name collision. This can be avoided by the using - // the right pattern such that name collisions do not occur. - var i = 0 - var altRolloverFile: File = null - do { - altRolloverFile = new File(activeFile.getParent, - s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile - i += 1 - } while (i < 10000 && altRolloverFile.exists) - - logWarning(s"Rollover file $rolloverFile already exists, " + - s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) - } + logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") + if (activeFile.exists) { + if (!rolloverFile.exists) { + Files.move(activeFile, rolloverFile) + logInfo(s"Rolled over $activeFile to $rolloverFile") } else { - logWarning(s"File $activeFile does not exist") + // In case the rollover file name clashes, make a unique file name. + // The resultant file names are long and ugly, so this is used only + // if there is a name collision. This can be avoided by the using + // the right pattern such that name collisions do not occur. 
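// Worked example (editor's addition, not part of this patch; file names are made up).
// The fallback loop below appends "--<i>" to a clashing rollover name: for an active
// file "app.log" and rollover suffix ".1", the candidates tried are "app.log.1--0",
// "app.log.1--1", ... up to 10000 attempts.
def altRolloverName(activeName: String, rolloverSuffix: String, i: Int): String =
  s"$activeName$rolloverSuffix--$i" // mirrors the name built inside the loop below
// (end of sketch)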
+ var i = 0 + var altRolloverFile: File = null + do { + altRolloverFile = new File(activeFile.getParent, + s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile + i += 1 + } while (i < 10000 && altRolloverFile.exists) + + logWarning(s"Rollover file $rolloverFile already exists, " + + s"rolled over $activeFile to file $altRolloverFile") + Files.move(activeFile, altRolloverFile) } + } else { + logWarning(s"File $activeFile does not exist") } } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ebd3d61ae7324..fd8f7f39b7cc8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -90,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -103,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -133,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -165,47 +165,49 @@ public void randomSplit() { @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 
8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -214,10 +216,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -228,35 +230,37 @@ public void emptyRDD() { @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd = sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -265,7 +269,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } }); @@ -278,7 +282,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -301,7 +305,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 
3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -317,10 +321,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -390,18 +394,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -411,23 +414,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -439,27 +441,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD 
countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -471,16 +472,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -548,11 +549,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -570,20 +571,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -602,11 +603,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ -690,7 +691,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } 
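The Java hunks above are mechanical cleanups (diamond operator, @Override annotations); the pair-RDD behaviour they exercise — reduceByKey, foldByKey, aggregateByKey — is unchanged. For reference, a short spark-shell-style Scala sketch of the same operations (illustrative only; assumes the shell's predefined SparkContext bound to sc):

    // Same data as the Java foldByKey/reduceByKey tests: (2,1), (2,1), (1,1), (3,2), (3,1)
    val pairs = sc.parallelize(Seq((2, 1), (2, 1), (1, 1), (3, 2), (3, 1)))

    // Per-key sums, two equivalent ways.
    val reduced = pairs.reduceByKey(_ + _).collect().toMap   // Map(1 -> 1, 2 -> 2, 3 -> 3)
    val folded  = pairs.foldByKey(0)(_ + _).collect().toMap  // same result, with an explicit zero

    // Per-key value sets (the pattern the aggregateByKey test checks).
    val valueSets = pairs.aggregateByKey(Set.empty[Int])(_ + _, _ ++ _).collect().toMap
    // e.g. valueSets(3) == Set(1, 2)
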
@Test @@ -743,6 +744,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -766,14 +768,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -809,7 +811,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -844,7 +846,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -870,26 +872,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -897,36 +898,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but 
it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -953,7 +954,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -972,8 +973,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -982,7 +983,7 @@ public void repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -994,9 +995,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1046,7 +1047,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1075,16 +1076,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1093,7 +1094,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1110,7 +1111,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1131,14 +1132,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = 
ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1162,7 +1163,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1180,24 +1181,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1210,24 +1210,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1251,9 +1250,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD 
rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1267,23 +1266,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1296,16 +1294,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1313,8 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1414,8 +1411,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1448,20 +1445,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1496,21 +1493,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 
call(Tuple2 in) { - return new Tuple2(in._2(), in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1523,7 +1520,7 @@ public void collectPartitions() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1534,23 +1531,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1561,15 +1558,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1587,7 +1584,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1598,7 +1595,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1615,7 +1612,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1623,12 +1620,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - 
Assert.assertTrue(worCounts.size() == 2); + Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1641,7 +1638,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1649,25 +1646,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1679,7 +1676,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1716,7 +1713,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1745,7 +1742,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. } }); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 94650be536b5f..a266b0c36e0fa 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -530,8 +530,9 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { writer.insertRecordIntoSorter(new Tuple2(1, 1)); newPeakMemory = writer.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { - // We allocated a new page for this record, so peak memory should change + if (i % numRecordsPerPage == 0 && i != 0) { + // The first page is allocated in constructor, another page will be allocated after + // every numRecordsPerPage records (peak memory should change). 
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { assertEquals(previousPeakMemory, newPeakMemory); diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d343bb95cb68c..4d70bfed909b6 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -483,7 +483,7 @@ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) + def cogroup[K: ClassTag, V: ClassTag](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) : RDD[(K, Array[Iterable[V]])] = { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 418763f4e5ffa..fdb00aafc4a48 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -506,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index f34aefca4eb18..f29160d834082 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -125,6 +125,47 @@ class SecurityManagerSuite extends SparkFunSuite { } + test("set security with * in acls") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "*") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAcls with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for modifyAcls with * + securityManager.setModifyAcls(Set("user4"), "*") + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + + 
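The new SecurityManagerSuite test surrounding this point asserts the semantics of a "*" entry in the view, modify, and admin ACLs: a wildcard grants the permission to every user, while plain entries still match exactly (the rest of the test, just below, resets the ACLs and re-checks both cases). A minimal sketch of that membership rule (a simplified illustration, not Spark's SecurityManager):

    object AclWildcardSketch {
      // "*" in an ACL set grants access to everyone; otherwise membership is exact.
      def aclAllows(acl: Set[String], user: String): Boolean =
        acl.contains("*") || acl.contains(user)

      val everyoneCanView   = aclAllows(Set("*"), "user5")                // true
      val strangerCanModify = aclAllows(Set("user1", "user2"), "user5")   // false
    }
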
securityManager.setAdminAcls("user1,user2") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions("user6") === false) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for adminAcls with * + securityManager.setAdminAcls("user1,*") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() val expectedAlgorithms = Set( diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 48509f0759a3b..cda2b245526f7 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } @@ -145,18 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } + throwable.foreach { t => throw t } } test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -165,20 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === null) + throwable.foreach { t => throw t } } test("set and get local properties in parent-children thread") { sc = new 
SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -188,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) + throwable.foreach { t => throw t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 96e456d889ac3..9693e32bf6af6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -366,6 +366,18 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) } + test("client does not send 'SPARK_ENV_LOADED' env var by default") { + val environmentVariables = Map("SPARK_VAR" -> "1", "SPARK_ENV_LOADED" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1")) + } + + test("client includes mesos env vars") { + val environmentVariables = Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1", "OTHER_VAR" -> "1") + val filteredVariables = RestSubmissionClient.filterSystemEnvironment(environmentVariables) + assert(filteredVariables == Map("SPARK_VAR" -> "1", "MESOS_VAR" -> "1")) + } + /* --------------------- * | Helper methods | * --------------------- */ diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index d3218a548efc7..44eb5a0469122 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -286,6 +286,10 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Option[Long]): Long = { val taskMetrics = new ArrayBuffer[Long]() + + // Avoid receiving earlier taskEnd events + sc.listenerBus.waitUntilEmpty(500) + sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { collector(taskEnd).foreach(taskMetrics += _) diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala deleted file mode 100644 index 5e364cc0edeb2..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.nio._ - -import scala.concurrent.duration._ -import scala.concurrent.{Await, TimeoutException} -import scala.language.postfixOps - -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils - -/** - * Test the ConnectionManager with various security settings. - */ -class ConnectionManagerSuite extends SparkFunSuite { - - test("security default off") { - val conf = new SparkConf - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var receivedMessage = false - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - receivedMessage = true - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) - - assert(receivedMessage == true) - - manager.stop() - } - - test("security on same password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - conf.set("spark.app.id", "app-id") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - val managerServer = new ConnectionManager(0, conf, securityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val count = 10 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - }) - - assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch password") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.app.id", "app-id") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = conf.clone.set("spark.authenticate.secret", "bad") - val badsecurityManager = new SecurityManager(badconf) - val managerServer 
= new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - // Expect managerServer to close connection, which we'll report as an error: - intercept[IOException] { - Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) - } - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("security mismatch auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "good") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 1).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - assert(false) - } catch { - case i: IOException => - assert(true) - case e: TimeoutException => { - // we should timeout here since the client can't do the negotiation - assert(true) - } - } - }) - - assert(numReceivedServerMessages == 0) - assert(numReceivedMessages == 0) - manager.stop() - managerServer.stop() - } - - test("security auth off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - var numReceivedMessages = 0 - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedMessages += 1 - None - }) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - val badsecurityManager = new SecurityManager(badconf) - val managerServer = new ConnectionManager(0, badconf, badsecurityManager) - var numReceivedServerMessages = 0 - - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - numReceivedServerMessages += 1 - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - (0 until 10).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(managerServer.id, bufferMessage) - }).foreach(f => { - try { - val g = Await.result(f, 1 second) - } catch { - case e: Exception => { - assert(false) - } - } - }) - 
assert(numReceivedServerMessages == 10) - assert(numReceivedMessages == 0) - - manager.stop() - managerServer.stop() - } - - test("Ack error message") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - val securityManager = new SecurityManager(conf) - val manager = new ConnectionManager(0, conf, securityManager) - val managerServer = new ConnectionManager(0, conf, securityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception("Custom exception text") - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - val exception = intercept[IOException] { - Await.result(future, 1 second) - } - assert(Utils.exceptionString(exception).contains("Custom exception text")) - - manager.stop() - managerServer.stop() - - } - - test("sendMessageReliably timeout") { - val clientConf = new SparkConf - clientConf.set("spark.authenticate", "false") - val ackTimeoutS = 30 - clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeoutS}s") - - val clientSecurityManager = new SecurityManager(clientConf) - val manager = new ConnectionManager(0, clientConf, clientSecurityManager) - - val serverConf = new SparkConf - serverConf.set("spark.authenticate", "false") - val serverSecurityManager = new SecurityManager(serverConf) - val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) - managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // sleep 60 sec > ack timeout for simulating server slow down or hang up - Thread.sleep(ackTimeoutS * 3 * 1000) - None - }) - - val size = 10 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - - val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - - // Future should throw IOException in 30 sec. - // Otherwise TimeoutExcepton is thrown from Await.result. - // We expect TimeoutException is not thrown. 
- intercept[IOException] { - Await.result(future, (ackTimeoutS * 2) second) - } - - manager.stop() - managerServer.stop() - } - -} - diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index f65349e3e3585..16a92f54f9368 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -38,6 +38,13 @@ class RDDOperationScopeSuite extends SparkFunSuite with BeforeAndAfter { sc.stop() } + test("equals and hashCode") { + val opScope1 = new RDDOperationScope("scope1", id = "1") + val opScope2 = new RDDOperationScope("scope1", id = "1") + assert(opScope1 === opScope2) + assert(opScope1.hashCode() === opScope2.hashCode()) + } + test("getAllScopes") { assert(scope1.getAllScopes === Seq(scope1)) assert(scope2.getAllScopes === Seq(scope1, scope2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2e8688cf41d99..1b9ff740ff530 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -26,11 +26,11 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite -import org.apache.spark.executor.TaskMetrics class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -285,6 +285,15 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("equals and hashCode AccumulableInfo") { + val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) + val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + assert(accInfo1 !== accInfo2) + assert(accInfo2 === accInfo3) + assert(accInfo2.hashCode() === accInfo3.hashCode()) + } + test("cache location preferences w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil).cache() val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -473,6 +482,282 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + + // Helper function to validate state when creating tests for task failures + private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) { + assert(stageAttempt.stageId === stageId) + assert(stageAttempt.stageAttemptId == attempt) + } + + + // Helper functions to extract commonly used code in Fetch Failure test cases + private def setupStageAbortTest(sc: SparkContext) { + sc.listenerBus.addListener(new EndListener()) + ended = false + jobResult = null + } + + // Create a new Listener to confirm that the listenerBus sees the JobEnd message + // when we abort the stage. This message will also be consumed by the EventLoggingListener + // so this will propagate up to the user. 
+ var ended = false + var jobResult : JobResult = null + + class EndListener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobResult = jobEnd.jobResult + ended = true + } + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * successfully. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param numShufflePartitions - The number of partitions in the next stage + */ + private def completeShuffleMapStageSuccessfully( + stageId: Int, + attemptIdx: Int, + numShufflePartitions: Int): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { + case (task, idx) => + (Success, makeMapStatus("host" + ('A' + idx).toChar, numShufflePartitions)) + }.toSeq) + } + + /** + * Common code to get the next stage attempt, confirm it's the one we expect, and complete it + * with all FetchFailure. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + * @param shuffleDep - The shuffle dependency of the stage with a fetch failure + */ + private def completeNextStageWithFetchFailure( + stageId: Int, + attemptIdx: Int, + shuffleDep: ShuffleDependency[_, _, _]): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null) + }.toSeq) + } + + /** + * Common code to get the next result stage attempt, confirm it's the one we expect, and + * complete it with a success where we return 42. + * + * @param stageId - The current stageId + * @param attemptIdx - The current attempt count + */ + private def completeNextResultStageWithSuccess(stageId: Int, attemptIdx: Int): Unit = { + val stageAttempt = taskSets.last + checkStageId(stageId, attemptIdx, stageAttempt) + assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) + complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map(_ => (Success, 42)).toSeq) + } + + /** + * In this test, we simulate a job where many tasks in the same stage fail. We want to show + * that many fetch failures inside a single stage attempt do not trigger an abort + * on their own, but only when there are enough failing stage attempts. 
+ */ + test("Single stage fetch failure should not abort the stage.") { + setupStageAbortTest(sc) + + val parts = 8 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep)) + submit(reduceRdd, (0 until parts).toArray) + + completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts) + + completeNextStageWithFetchFailure(1, 0, shuffleDep) + + // Resubmit and confirm that now all is well + scheduler.resubmitFailedStages() + + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Complete stage 0 and then stage 1 with a "42" + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts) + completeNextResultStageWithSuccess(1, 1) + + // Confirm job finished succesfully + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === (0 until parts).map { idx => idx -> 42 }.toMap) + assertDataStructuresEmpty() + } + + /** + * In this test we simulate a job failure where the first stage completes successfully and + * the second stage fails due to a fetch failure. Multiple successive fetch failures of a stage + * trigger an overall job abort to avoid endless retries. + */ + test("Multiple consecutive stage fetch failures should lead to job being aborted.") { + setupStageAbortTest(sc) + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDep) + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } else { + // Stage should have been aborted and removed from running stages + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended) + jobResult match { + case JobFailed(reason) => + assert(reason.getMessage.contains("ResultStage 1 () has failed the maximum")) + case other => fail(s"expected JobFailed, not $other") + } + } + } + } + + /** + * In this test, we create a job with two consecutive shuffles, and simulate 2 failures for each + * shuffle fetch. In total In total, the job has had four failures overall but not four failures + * for a particular stage, and as such should not be aborted. + */ + test("Failures in different stages should not trigger an overall abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + + // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, + // stage 2 fails. 
+ for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + // Complete all the tasks for the current attempt of stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail all these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + } else { + completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) + + // Fail stage 2 + completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, + shuffleDepTwo) + } + + // this will trigger a resubmission of stage 0, since we've lost some of its + // map output, for the next iteration through the loop + scheduler.resubmitFailedStages() + } + + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) + + // Succeed stage 2 with a "42" + completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) + + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty() + } + + /** + * In this test we demonstrate that only consecutive failures trigger a stage abort. A stage may + * fail multiple times, succeed, then fail a few more times (because it's run again by downstream + * dependencies). The total number of failed attempts for one stage will go over the limit, + * but that doesn't matter, since they have successes in the middle. + */ + test("Non-consecutive stage failures don't trigger abort") { + setupStageAbortTest(sc) + + val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache() + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache() + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + + // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. + for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + // Complete each task in stage 0 successfully + completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) + + // Now we should have a new taskSet, for a new attempt of stage 1. + // Fail these tasks with FetchFailure + completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + } + + // Rerun stage 0 and 1 to step through the task set + completeShuffleMapStageSuccessfully(0, 3, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 3, numShufflePartitions = 1) + + // Fail stage 2 so that stage 1 is resubmitted when we call scheduler.resubmitFailedStages() + completeNextStageWithFetchFailure(2, 0, shuffleDepTwo) + + scheduler.resubmitFailedStages() + + // Rerun stage 0 to step through the task set + completeShuffleMapStageSuccessfully(0, 4, numShufflePartitions = 2) + + // Now again, fail stage 1 (up to MAX_FAILURES) but confirm that this doesn't trigger an abort + // since we succeeded in between.
+ completeNextStageWithFetchFailure(1, 4, shuffleDepOne) + + scheduler.resubmitFailedStages() + + // Confirm we have not yet aborted + assert(scheduler.runningStages.nonEmpty) + assert(!ended) + + // Next, succeed all and confirm output + // Rerun stage 0 + 1 + completeShuffleMapStageSuccessfully(0, 5, numShufflePartitions = 2) + completeShuffleMapStageSuccessfully(1, 5, numShufflePartitions = 1) + + // Succeed stage 2 and verify results + completeNextResultStageWithSuccess(2, 1) + + assertDataStructuresEmpty() + sc.listenerBus.waitUntilEmpty(1000) + assert(ended === true) + assert(results === Map(0 -> 42)) + } + test("trivial shuffle with multiple fetch failures") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -527,6 +812,7 @@ class DAGSchedulerSuite } // The map stage should have been submitted. + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(countSubmittedMapStageAttempts() === 1) complete(taskSets(0), Seq( @@ -650,27 +936,64 @@ class DAGSchedulerSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) + // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA")) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) + + // now start completing some tasks in the shuffle map stage, under different hosts + // and epochs, and make sure scheduler updates its state correctly val taskSet = taskSets(0) + val shuffleStage = scheduler.stageIdToStage(taskSet.stageId).asInstanceOf[ShuffleMapStage] + assert(shuffleStage.numAvailableOutputs === 0) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 0) + + // should work because it's a non-failed host (so the available map outputs will increase) + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostB", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 1) + // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - // should work because it's a new epoch + runEvent(CompletionEvent( + taskSet.tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 1) + + // should work because it's a new epoch, which will increase the number of available map + // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", - reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSet.tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.size), + null, + createFakeTaskInfo(), + null)) + assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 
0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + + // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -810,7 +1133,7 @@ class DAGSchedulerSuite submit(finalRdd, Array(0)) cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - // complete stage 2 + // complete stage 0 complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) @@ -818,7 +1141,7 @@ class DAGSchedulerSuite complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // pretend stage 0 failed because hostA went down + // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index edbdb485c5ea4..f0eadf240943e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -334,7 +334,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Now mark host2 as dead sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") + manager.executorLost("exec2", "host2", SlaveLost()) // nothing should be chosen assert(manager.resourceOffer("exec1", "host1", ANY) === None) @@ -504,13 +504,36 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) // test if the valid locality is recomputed when the executor is lost sched.removeExecutor("execC") - manager.executorLost("execC", "host2") + manager.executorLost("execC", "host2", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) sched.removeExecutor("execD") - manager.executorLost("execD", "host1") + manager.executorLost("execD", "host1", SlaveLost()) assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } + test("Executors are added but exit normally while running tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execB")), + Seq(TaskLocation("host2", "execC")), + Seq()) + val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + sched.addExecutor("execA", "host1") + manager.executorAdded() + sched.addExecutor("execC", "host2") + manager.executorAdded() + assert(manager.resourceOffer("exec1", "host1", ANY).isDefined) + sched.removeExecutor("execA") + manager.executorLost("execA", "host1", ExecutorExited(143, true, "Normal termination")) + assert(!sched.taskSetsFailed.contains(taskSet.id)) + assert(manager.resourceOffer("execC", "host2", ANY).isDefined) + sched.removeExecutor("execC") + manager.executorLost("execC", "host2", ExecutorExited(1, false, "Abnormal termination")) + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + test("test RACK_LOCAL tasks") { // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") @@ -721,8 +744,8 @@ class TaskSetManagerSuite extends SparkFunSuite 
with LocalSparkContext with Logg assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) sched.removeExecutor("execA") sched.removeExecutor("execB.2") - manager.executorLost("execA", "host1") - manager.executorLost("execB.2", "host2") + manager.executorLost("execA", "host1", SlaveLost()) + manager.executorLost("execB.2", "host2", SlaveLost()) clock.advance(LOCALITY_WAIT_MS * 4) sched.addExecutor("execC", "host3") manager.executorAdded() diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 319b3173e7a6e..c4dc560031207 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -42,6 +42,38 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + test("Use configured mesosExecutor.cores for ExecutorInfo") { + val mesosExecutorCores = 3 + val conf = new SparkConf + conf.set("spark.mesos.mesosExecutor.cores", mesosExecutorCores.toString) + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + + val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val resources = Arrays.asList( + mesosSchedulerBackend.createResource("cpus", 4), + mesosSchedulerBackend.createResource("mem", 1024)) + // uri is null. + val (executorInfo, _) = mesosSchedulerBackend.createExecutorInfo(resources, "test-id") + val executorResources = executorInfo.getResourcesList + val cpus = executorResources.asScala.find(_.getName.equals("cpus")).get.getScalar.getValue + + assert(cpus === mesosExecutorCores) + } + test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") @@ -263,7 +295,6 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")) .setHostname(s"host${id.toString}").build() - val mesosOffers = new java.util.ArrayList[Offer] mesosOffers.add(offer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala new file mode 100644 index 0000000000000..4d5f599fb12ab --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDependencySuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark._ + +case class KeyClass() + +case class ValueClass() + +case class CombinerClass() + +class ShuffleDependencySuite extends SparkFunSuite with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + + test("key, value, and combiner classes correct in shuffle dependency without aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .groupByKey() + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!dep.mapSideCombine, "Test requires that no map-side aggregator is defined") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + } + + test("key, value, and combiner classes available in shuffle dependency with aggregation") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .aggregateByKey(CombinerClass())({ case (a, b) => a }, { case (a, b) => a }) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.mapSideCombine && dep.aggregator.isDefined, "Test requires map-side aggregation") + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == Some(classOf[CombinerClass].getName)) + } + + test("combineByKey null combiner class tag handled correctly") { + sc = new SparkContext("local", "test", conf.clone()) + val rdd = sc.parallelize(1 to 5, 4) + .map(key => (KeyClass(), ValueClass())) + .combineByKey((v: ValueClass) => v, + (c: AnyRef, v: ValueClass) => c, + (c1: AnyRef, c2: AnyRef) => c1) + val dep = rdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(dep.keyClassName == classOf[KeyClass].getName) + assert(dep.valueClassName == classOf[ValueClass].getName) + assert(dep.combinerClassName == None) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 0f5ba46f69c2f..eb5af70d57aec 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,10 +26,10 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -38,7 +38,7 @@ import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite 
extends SparkFunSuite with Matchers with BeforeAndAfter { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) @@ -59,7 +59,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f480fd107a0c2..34bb4952e7246 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -30,10 +30,10 @@ import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -44,9 +44,10 @@ import org.apache.spark.util._ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false) + private val conf = new SparkConf(false).set("spark.app.id", "test") var store: BlockManager = null var store2: BlockManager = null + var store3: BlockManager = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") @@ -65,7 +66,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") @@ -99,6 +100,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store2.stop() store2 = null } + if (store3 != null) { + store3.stop() + store3 = null + } rpcEnv.shutdown() rpcEnv.awaitTermination() rpcEnv = null @@ -443,6 +448,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { + val origTimeoutOpt = conf.getOption("spark.network.timeout") + try { + conf.set("spark.network.timeout", "2s") + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store3 = makeBlockManager(8000, "executor3") + val list1 = List(new Array[Byte](4000)) + store2.putIterator("list1", list1.iterator, 
StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + var list1Get = store.getRemoteBytes("list1") + assert(list1Get.isDefined, "list1Get expected to be fetched") + // block manager exit + store2.stop() + store2 = null + list1Get = store.getRemoteBytes("list1") + // get `list1` block + assert(list1Get.isDefined, "list1Get expected to be fetched") + store3.stop() + store3 = null + // exception thrown because there are no locations left + intercept[BlockFetchException] { + list1Get = store.getRemoteBytes("list1") + } + } finally { + origTimeoutOpt match { + case Some(t) => conf.set("spark.network.timeout", t) + case None => conf.remove("spark.network.timeout") + } + } + } + test("in-memory LRU storage") { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) @@ -782,7 +819,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -796,7 +833,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") + assert(store.getSingle("a1").isEmpty, "a1 should not be in store") } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 343a4139b0ca8..47e548ef0d442 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,7 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure("100")) + testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) // BlockId @@ -295,10 +295,10 @@ class JsonProtocolSuite extends SparkFunSuite { test("ExecutorLostFailure backward compatibility") { // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property.
- val executorLostFailure = ExecutorLostFailure("100") + val executorLostFailure = ExecutorLostFailure("100", true) val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) .removeField({ _._1 == "Executor ID" }) - val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true) assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -577,8 +577,10 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + case (ExecutorLostFailure(execId1, isNormalExit1), + ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) + assert(isNormalExit1 === isNormalExit2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py index 8aaa250bd7e29..db9c680a4bad3 100755 --- a/dev/create-release/generate-contributors.py +++ b/dev/create-release/generate-contributors.py @@ -178,13 +178,16 @@ def populate(issue_type, components): author_info[author][issue_type].add(component) # Find issues and components associated with this commit for issue in issues: - jira_issue = jira_client.issue(issue) - jira_type = jira_issue.fields.issuetype.name - jira_type = translate_issue_type(jira_type, issue, warnings) - jira_components = [translate_component(c.name, _hash, warnings)\ - for c in jira_issue.fields.components] - all_components = set(jira_components + commit_components) - populate(jira_type, all_components) + try: + jira_issue = jira_client.issue(issue) + jira_type = jira_issue.fields.issuetype.name + jira_type = translate_issue_type(jira_type, issue, warnings) + jira_components = [translate_component(c.name, _hash, warnings)\ + for c in jira_issue.fields.components] + all_components = set(jira_components + commit_components) + populate(jira_type, all_components) + except Exception as e: + print "Unexpected error:", e # For docs without an associated JIRA, manually add it ourselves if is_docs(title) and not issues: populate("documentation", commit_components) @@ -223,7 +226,8 @@ def populate(issue_type, components): # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672 if author in invalid_authors and invalid_authors[author]: author = author + "/" + "/".join(invalid_authors[author]) - line = " * %s -- %s" % (author, contribution) + #line = " * %s -- %s" % (author, contribution) + line = author contributors_file.write(line + "\n") contributors_file.close() print "Contributors list is successfully written to %s!" 
% contributors_file_name diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index e462302f28423..3563fe3cc3c03 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -138,3 +138,30 @@ lee19 - Lee lockwobr - Brian Lockwood navis - Navis Ryu pparkkin - Paavo Parkkinen +HyukjinKwon - Hyukjin Kwon +JDrit - Joseph Batchik +JuhongPark - Juhong Park +KaiXinXiaoLei - KaiXinXIaoLei +NamelessAnalyst - NamelessAnalyst +alyaxey - Alex Slusarenko +baishuo - Shuo Bai +fe2s - Oleksiy Dyagilev +felixcheung - Felix Cheung +feynmanliang - Feynman Liang +josepablocam - Jose Cambronero +kai-zeng - Kai Zeng +mosessky - mosessky +msannell - Michael Sannella +nishkamravi2 - Nishkam Ravi +noel-smith - Noel Smith +petz2000 - Patrick Baier +qiansl127 - Shilei Qian +rahulpalamuttam - Rahul Palamuttam +rowan000 - Rowan Chattaway +sarutak - Kousuke Saruta +sethah - Seth Hendrickson +small-wang - Wang Wei +stanzhai - Stan Zhai +tien-dungle - Tien-Dung Le +xuchenCN - Xu Chen +zhangjiajin - Zhang JiaJin diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 51ab25a6a5bd8..7f152b7f53559 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -24,7 +24,11 @@ try: from jira.client import JIRA - from jira.exceptions import JIRAError + # Old versions have JIRAError in exceptions package, new (0.5+) in utils. + try: + from jira.exceptions import JIRAError + except ImportError: + from jira.utils import JIRAError except ImportError: print "This tool requires the jira-python library" print "Install using 'sudo pip install jira'" diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile index b90aef3655dee..fb3f267fe5c78 100644 --- a/docker/spark-mesos/Dockerfile +++ b/docker/spark-mesos/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && \ apt-get install -y python libnss3 openjdk-7-jre-headless curl RUN mkdir /opt/spark && \ - curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + curl http://www.apache.org/dyn/closer.lua/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ | tar -xzC /opt ENV SPARK_HOME /opt/spark ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286ce..347ca4a7af989 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. 
diff --git a/docs/building-spark.md b/docs/building-spark.md index f133eb96d9a21..4db32cfd628bc 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -61,12 +61,13 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* *For Java 8 and above this step is not required.* -* *If using `build/mvn` and `MAVEN_OPTS` were not already set, the script will automate this for you.* + +* For Java 8 and above this step is not required. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. # Specifying the Hadoop Version -Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: +Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the `hadoop.version` property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: @@ -91,7 +92,7 @@ mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package {% endhighlight %} -You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. +You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. Examples: @@ -125,7 +126,7 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Dskip # Building for Scala 2.11 To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.11 mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package Spark does not yet support its JDBC component for Scala 2.11. @@ -163,11 +164,9 @@ the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: -``` - $ mvn install - $ cd core - $ mvn scala:cc -``` + $ mvn install + $ cd core + $ mvn scala:cc # Building Spark with IntelliJ IDEA or Eclipse @@ -193,11 +192,11 @@ then ship it over to the cluster. We are investigating the exact cause for this. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. 
On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT -Maven is the official recommendation for packaging Spark, and is the "build of reference". +Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. But SBT is supported for day-to-day development since it can provide much faster iterative compilation. More advanced developers may wish to use SBT. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 7079de546e2f5..faaf154d243f5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -5,18 +5,19 @@ title: Cluster Mode Overview This document gives a short overview of how Spark runs on clusters, to make it easier to understand the components involved. Read through the [application submission guide](submitting-applications.html) -to submit applications to a cluster. +to learn about launching applications on a cluster. # Components -Spark applications run as independent sets of processes on a cluster, coordinated by the SparkContext +Spark applications run as independent sets of processes on a cluster, coordinated by the `SparkContext` object in your main program (called the _driver program_). + Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ -(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across +(either Spark's own standalone cluster manager, Mesos or YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to -the executors. Finally, SparkContext sends *tasks* for the executors to run. +the executors. Finally, SparkContext sends *tasks* to the executors to run.
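
To make the flow in the cluster-overview paragraph above concrete, here is a minimal, self-contained Scala sketch of a driver program; the master URL, application name, and the toy job are illustrative placeholders, not anything taken from this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object ClusterOverviewSketch {
  def main(args: Array[String]): Unit = {
    // The driver program creates the SparkContext, which connects to a cluster
    // manager (a placeholder standalone master URL here) and acquires executors.
    val conf = new SparkConf()
      .setAppName("cluster-overview-sketch")    // illustrative app name
      .setMaster("spark://master-host:7077")    // placeholder cluster manager URL
    val sc = new SparkContext(conf)

    // The map and reduce below are shipped to the executors as tasks;
    // only the final reduced value comes back to the driver.
    val sum = sc.parallelize(1 to 1000).map(_ * 2).reduce(_ + _)
    println(s"sum = $sum")

    sc.stop()
  }
}
```
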

Spark cluster components @@ -33,9 +34,9 @@ There are several useful things to note about this architecture: 2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). -3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config - section](configuration.html#networking)). As such, the driver program must be network +3. The driver program must listen for and accept incoming connections from its executors throughout + its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker nodes, preferably on the same local area network. If you'd like to send requests to the diff --git a/docs/configuration.md b/docs/configuration.md index 77c5cbc7b3196..1a701f18881fe 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -382,17 +382,6 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. -

- - - - @@ -458,9 +447,12 @@ Apart from these, the following properties are also available, and may be useful @@ -1015,7 +1007,11 @@ Apart from these, the following properties are also available, and may be useful @@ -1116,10 +1112,11 @@ Apart from these, the following properties are also available, and may be useful - + @@ -1327,7 +1325,8 @@ Apart from these, the following properties are also available, and may be useful @@ -1349,7 +1348,8 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.blockTransferServicenetty - Implementation to use for transferring shuffle and cached blocks between executors. There - are two implementations available: netty and nio. Netty-based - block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2, and nio block transfer is deprecated in Spark 1.5.0 and will - be removed in Spark 1.6.0. -
spark.shuffle.compress truespark.shuffle.manager sort - Implementation to use for shuffling data. There are two implementations available: - sort and hash. Sort-based shuffle is more memory-efficient and is - the default option starting in 1.2. + Implementation to use for shuffling data. There are three implementations available: + sort, hash and the new (1.5+) tungsten-sort. + Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. + Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly + implementation with a fall back to regular sort based shuffle if its requirements are not + met.
spark.port.maxRetries 16 - Default maximum number of retries when binding to a port before giving up. + Maximum number of retries when binding to a port before giving up. + When a port is given a specific value (non 0), each subsequent retry will + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified + to port + maxRetries.
spark.scheduler.minRegisteredResourcesRatio 0.8 for YARN mode; 0.0 otherwise 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode, CPU cores in standalone mode) + (resources are executors in yarn mode, CPU cores in standalone mode and Mesos coarse-grained + mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config @@ -1286,7 +1283,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. + help debug when things work. Putting a "*" in the list means any user can have the privilege + of admin.
Empty Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). + user that started the Spark job has access to modify it (kill it for example). Putting a "*" in + the list means any user can have access to modify it.
Empty Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. + user that started the Spark job has view access. Putting a "*" in the list means any user can + have view access to this Spark job.
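
As a rough illustration of the ACL properties described in the rows above, the sketch below sets them programmatically on a SparkConf; the user names are invented, and only the property names and the "*" wildcard semantics come from the documentation text itself:

```scala
import org.apache.spark.SparkConf

// Hypothetical ACL setup; "alice,bob" and "carol" are made-up user names.
val conf = new SparkConf()
  .setAppName("acl-sketch")
  .set("spark.admin.acls", "alice,bob")  // admins get view and modify access to all jobs
  .set("spark.modify.acls", "*")         // "*": any user may modify (e.g. kill) this job
  .set("spark.ui.view.acls", "carol")    // users allowed to view this job's web UI
```
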
@@ -1437,6 +1437,19 @@ Apart from these, the following properties are also available, and may be useful #### Spark Streaming + + + + + diff --git a/docs/index.md b/docs/index.md index d85cf12defefd..c0dc2b8d7412a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,7 +90,6 @@ options for deployment: * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 62749909e01dc..58f566c9b4b55 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -121,10 +121,9 @@ import org.apache.spark.ml.classification.RandomForestClassifier import org.apache.spark.ml.classification.RandomForestClassificationModel import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -193,14 +192,11 @@ import org.apache.spark.ml.classification.RandomForestClassifier; import org.apache.spark.ml.classification.RandomForestClassificationModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -268,10 +264,9 @@ from pyspark.ml import Pipeline from pyspark.ml.classification import RandomForestClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -327,10 +322,9 @@ import org.apache.spark.ml.regression.RandomForestRegressor import org.apache.spark.ml.regression.RandomForestRegressionModel import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. 
-val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -387,14 +381,11 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.RandomForestRegressionModel; import org.apache.spark.ml.regression.RandomForestRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -450,10 +441,9 @@ from pyspark.ml import Pipeline from pyspark.ml.regression import RandomForestRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -576,10 +566,9 @@ import org.apache.spark.ml.classification.GBTClassifier import org.apache.spark.ml.classification.GBTClassificationModel import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -648,14 +637,10 @@ import org.apache.spark.ml.classification.GBTClassifier; import org.apache.spark.ml.classification.GBTClassificationModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index.
@@ -724,10 +709,9 @@ from pyspark.ml import Pipeline from pyspark.ml.classification import GBTClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -783,10 +767,9 @@ import org.apache.spark.ml.regression.GBTRegressor import org.apache.spark.ml.regression.GBTRegressionModel import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils // Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -844,14 +827,10 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.GBTRegressionModel; import org.apache.spark.ml.regression.GBTRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -908,10 +887,9 @@ from pyspark.ml import Pipeline from pyspark.ml.regression import GBTRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. 
@@ -970,15 +948,14 @@ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifie {% highlight scala %} import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Row, SQLContext} val sqlContext = new SQLContext(sc) // parse data into dataframe -val data = MLUtils.loadLibSVMFile(sc, - "data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") +val Array(train, test) = data.randomSplit(Array(0.7, 0.3)) // instantiate multiclass learner and train val ovr = new OneVsRest().setClassifier(new LogisticRegression) @@ -1016,9 +993,6 @@ import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -1026,10 +1000,9 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); -RDD data = MLUtils.loadLibSVMFile(jsc.sc(), - "data/mllib/sample_multiclass_classification_data.txt"); +DataFrame dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); DataFrame train = splits[0]; DataFrame test = splits[1]; diff --git a/docs/ml-features.md b/docs/ml-features.md index 90654d1e5a248..a414c21b5c280 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -512,6 +512,25 @@ DataFrame dataset = jsql.createDataFrame(rdd, schema); remover.transform(dataset).show(); {% endhighlight %} + +
+[`StopWordsRemover`](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %} +
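The snippet added above mirrors the existing Scala and Java `StopWordsRemover` examples. For reference, a minimal Scala sketch of the same transformation (assuming a `sqlContext` and the same toy data) would be:

{% highlight scala %}
import org.apache.spark.ml.feature.StopWordsRemover

// Same toy data as the Python example above.
val sentenceData = sqlContext.createDataFrame(Seq(
  (0, Seq("I", "saw", "the", "red", "baloon")),
  (1, Seq("Mary", "had", "a", "little", "lamb"))
)).toDF("label", "raw")

// Filter English stop words from the "raw" column into "filtered".
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
remover.transform(sentenceData).show()
{% endhighlight %}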
## $n$-gram @@ -1160,9 +1179,9 @@ In the example below, we read in a dataset of labeled points and then use `Vecto
{% highlight scala %} import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -1181,16 +1200,12 @@ val indexedData = indexerModel.transform(data) {% highlight java %} import java.util.Map; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD rdd = MLUtils.loadLibSVMFile(sc.sc(), - "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); +DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -1211,9 +1226,9 @@ DataFrame indexedData = indexerModel.transform(data);
{% highlight python %} from pyspark.ml.feature import VectorIndexer -from pyspark.mllib.util import MLUtils -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +data = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) indexerModel = indexer.fit(data) @@ -1234,10 +1249,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
{% highlight scala %} import org.apache.spark.ml.feature.Normalizer -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") // Normalize each Vector using $L^1$ norm. val normalizer = new Normalizer() @@ -1253,15 +1267,11 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = jsql.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Normalize each Vector using $L^1$ norm. Normalizer normalizer = new Normalizer() @@ -1278,11 +1288,10 @@ DataFrame lInfNormData =
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import Normalizer -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") # Normalize each Vector using $L^1$ norm. normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) @@ -1316,10 +1325,9 @@ The following example demonstrates how to load a dataset in libsvm format and th
{% highlight scala %} import org.apache.spark.ml.feature.StandardScaler -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1336,16 +1344,12 @@ val scaledData = scalerModel.transform(dataFrame)
{% highlight java %} -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StandardScaler; import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = jsql.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); StandardScaler scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1362,11 +1366,10 @@ DataFrame scaledData = scalerModel.transform(dataFrame);
{% highlight python %} -from pyspark.mllib.util import MLUtils from pyspark.ml.feature import StandardScaler -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -dataFrame = sqlContext.createDataFrame(data) +dataFrame = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False) @@ -1405,10 +1408,9 @@ More details can be found in the API docs for [MinMaxScalerModel](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel). {% highlight scala %} import org.apache.spark.ml.feature.MinMaxScaler -import org.apache.spark.mllib.util.MLUtils -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val dataFrame = sqlContext.createDataFrame(data) +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -1429,13 +1431,10 @@ More details can be found in the API docs for import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.MinMaxScaler; import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; -JavaRDD data = - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); -DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame dataFrame = jsql.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); MinMaxScaler scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaledFeatures"); diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index cdd9d4999fa1b..4e94e2f9c708d 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -59,10 +59,9 @@ $\alpha$ and `regParam` corresponds to $\lambda$.
{% highlight scala %} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.mllib.util.MLUtils // Load training data -val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LogisticRegression() .setMaxIter(10) @@ -81,8 +80,6 @@ println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") {% highlight java %} import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.sql.DataFrame; @@ -98,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + DataFrame training = sql.read().format("libsvm").load(path); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) @@ -118,11 +115,9 @@ public class LogisticRegressionWithElasticNetExample {
{% highlight python %} from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) @@ -251,10 +246,9 @@ regression model and extracting model summary statistics.
{% highlight scala %} import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.mllib.util.MLUtils // Load training data -val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LinearRegression() .setMaxIter(10) @@ -283,8 +277,6 @@ import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.regression.LinearRegressionModel; import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.sql.DataFrame; @@ -300,7 +292,7 @@ public class LinearRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + DataFrame training = sqlContext.read.format("libsvm").load(path); LinearRegression lr = new LinearRegression() .setMaxIter(10) @@ -329,11 +321,9 @@ public class LinearRegressionWithElasticNetExample { {% highlight python %} from pyspark.ml.regression import LinearRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 065bf4727624f..d8c7bdc63c70e 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -144,7 +144,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)); // Create a labeled point with a negative label and a sparse feature vector. -LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); +LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0})); {% endhighlight %}
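The `mllib-data-types.md` hunk above corrects only the label in the Java example, since a point described as negative should carry the label 0.0. For reference, a minimal Scala sketch of the same two constructions (using the standard `LabeledPoint` and `Vectors` factories) is:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// A labeled point with a positive label and a dense feature vector.
val pos = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0))

// A labeled point with a negative label (0.0) and a sparse feature vector
// (size 3, with indices 0 and 2 set to 1.0 and 3.0).
val neg = LabeledPoint(0.0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)))
{% endhighlight %}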
diff --git a/docs/quick-start.md b/docs/quick-start.md index ce2cc9d2169cd..d481fe0ea6d70 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -126,7 +126,7 @@ scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (w wordCounts: spark.RDD[(String, Int)] = spark.ShuffledAggregatedRDD@71f027b8 {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight scala %} scala> wordCounts.collect() @@ -163,7 +163,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b) {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations) and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: {% highlight python %} >>> wordCounts.collect() @@ -217,13 +217,13 @@ a cluster, as described in the [programming guide](programming-guide.html#initia
# Self-Contained Applications -Now say we wanted to write a self-contained application using the Spark API. We will walk through a -simple application in both Scala (with SBT), Java (with Maven), and Python. +Suppose we wish to write a self-contained application using the Spark API. We will walk through a +simple application in Scala (with sbt), Java (with Maven), and Python.
-We'll create a very simple Spark application in Scala. So simple, in fact, that it's +We'll create a very simple Spark application in Scala--so simple, in fact, that it's named `SimpleApp.scala`: {% highlight scala %} @@ -259,7 +259,7 @@ object which contains information about our application. Our application depends on the Spark API, so we'll also include an sbt configuration file, -`simple.sbt` which explains that Spark is a dependency. This file also adds a repository that +`simple.sbt`, which explains that Spark is a dependency. This file also adds a repository that Spark depends on: {% highlight scala %} @@ -302,7 +302,7 @@ Lines with a: 46, Lines with b: 23
-This example will use Maven to compile an application jar, but any similar build system will work. +This example will use Maven to compile an application JAR, but any similar build system will work. We'll create a very simple Spark application, `SimpleApp.java`: @@ -374,7 +374,7 @@ $ find . Now, we can package the application using Maven and execute it with `./bin/spark-submit`. {% highlight bash %} -# Package a jar containing your application +# Package a JAR containing your application $ mvn package ... [INFO] Building jar: {..}/{..}/target/simple-project-1.0.jar diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index cfd219ab02e26..247e6ecfbdb86 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -45,7 +45,7 @@ frameworks. You can install Mesos either from source or using prebuilt packages To install Apache Mesos from source, follow these steps: 1. Download a Mesos release from a - [mirror](http://www.apache.org/dyn/closer.cgi/mesos/{{site.MESOS_VERSION}}/) + [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/) 2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and installing Mesos @@ -157,6 +157,8 @@ From the client, you can submit a job to Mesos cluster by running `spark-submit` to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves. + # Mesos Run Modes Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index de0461010daec..383d954409ce4 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,8 +5,6 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. -Python API Flume is not yet available in the Python API. - ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 118ced298f4b0..c751dbb41785a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -50,13 +50,7 @@ all of which are presented in this guide. You will find tabs throughout this guide that let you choose between code snippets of different languages. -**Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream -transformations and almost all the output operations available in Scala and Java interfaces. -However, it only has support for basic sources like text files and text data over sockets. -APIs for additional sources, like Kafka and Flume, will be available in the future. -Further information about available features in the Python API are mentioned throughout this -document; look out for the tag -Python API. +**Note:** There are a few APIs that are either different or not available in Python. 
Throughout this guide, you will find the tag Python API highlighting these differences. *************************************************************************************************** @@ -683,7 +677,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, Kafka, Kinesis, Flume, and MQTT are available in the Python API. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts @@ -725,9 +719,9 @@ Some of these advanced sources are as follows. - **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. -- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. +- **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. - **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information @@ -1813,7 +1807,7 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *[Since Spark 1.2] Configuring write ahead logs* - Since Spark 1.2, +- *Configuring write ahead logs* - Since Spark 1.2, we have introduced _write ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into a write ahead log in the configuration checkpoint directory. This prevents data loss on driver @@ -1828,6 +1822,17 @@ To run a Spark Streaming applications, you need to have the following. stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. +- *Setting the max receiving rate* - If the cluster resources are not large enough for the streaming + application to process data as fast as it is being received, the receivers can be rate limited + by setting a maximum rate limit in terms of records per second. + See the [configuration parameters](configuration.html#spark-streaming) + `spark.streaming.receiver.maxRate` for receivers and `spark.streaming.kafka.maxRatePerPartition` + for the Direct Kafka approach. In Spark 1.5, we have introduced a feature called *backpressure* that + eliminates the need to set this rate limit, as Spark Streaming automatically figures out the + rate limits and dynamically adjusts them if the processing conditions change.
This backpressure + can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) + `spark.streaming.backpressure.enabled` to `true`. + ### Upgrading Application Code {:.no_toc} diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 11fd7ee0ec8df..3a2361c6d6d2b 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.4.0" +SPARK_EC2_VERSION = "1.5.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -71,6 +71,8 @@ "1.3.0", "1.3.1", "1.4.0", + "1.4.1", + "1.5.0" ]) SPARK_TACHYON_MAP = { @@ -84,6 +86,8 @@ "1.3.0": "0.5.0", "1.3.1": "0.5.0", "1.4.0": "0.6.4", + "1.4.1": "0.6.4", + "1.5.0": "0.7.1" } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -91,7 +95,7 @@ # Default location to get the spark-ec2 scripts (and ami-list) from DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" def setup_external_libs(libs): diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index fa07c1e5017cd..d1b9b8d398dd8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -81,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -135,3 +137,4 @@ object CassandraCQLTest { } } // scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 2e56d24c60c33..1e679bfb55343 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { } } // scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2bf99cb3cba1f..c8780aa83bdbd 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -43,7 +43,7 @@ import org.jboss.netty.handler.codec.compression._ private[streaming] class FlumeInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel, diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 0bc46209b8369..3b936d88abd3e 
100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -46,7 +46,7 @@ import org.apache.spark.streaming.flume.sink._ * @tparam T Class type of the object of this stream */ private[streaming] class FlumePollingInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, val addresses: Seq[InetSocketAddress], val maxBatchSize: Int, val parallelism: Int, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 1000094e93cb3..8a087474d3169 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -58,7 +58,7 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R @@ -79,7 +79,7 @@ class DirectKafkaInputDStream[ override protected[streaming] val rateController: Option[RateController] = { if (RateController.isBackPressureEnabled(ssc.conf)) { Some(new DirectKafkaRateController(id, - RateEstimator.create(ssc.conf, ssc_.graph.batchDuration))) + RateEstimator.create(ssc.conf, context.graph.batchDuration))) } else { None } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 04b2dc10d39ea..38730fecf332a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -48,7 +48,7 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 9db07d0507fea..fbdfbf7e509b3 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -75,11 +75,11 @@ public void testKafkaStream() throws InterruptedException { String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public 
JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613ca..afcc6cfccd39a 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = "topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) 
throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b767..1e69de46cd35d 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { @Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 7c2f18cb35bda..116c170489e96 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class MQTTInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 7cf02d85d73d3..d7de74b350543 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ 
b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class TwitterInputDStream( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531d..26ec8af455bcf 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce52..14975265ab2ce 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = 
sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 4611a3ace219b..ee7302a1edbf6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -38,8 +38,8 @@ import org.apache.spark.graphx.impl.EdgeRDDImpl * `impl.ReplicatedVertexView`. 
*/ abstract class EdgeRDD[ED]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index db73a8abc5733..869caa340f52b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -46,7 +46,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @note vertex ids are unique. * @return an RDD containing the vertices in this graph */ - @transient val vertices: VertexRDD[VD] + val vertices: VertexRDD[VD] /** * An RDD containing the edges and their associated attributes. The entries in the RDD contain @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. * */ - @transient val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -77,7 +77,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * val numInvalid = graph.triplets.map(e => if (e.src.data == e.dst.data) 1 else 0).sum * }}} */ - @transient val triplets: RDD[EdgeTriplet[VD, ED]] + val triplets: RDD[EdgeTriplet[VD, ED]] /** * Caches the vertices and edges associated with this graph at the specified storage level, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index a9f04b559c3d1..1ef7a78fbcd00 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -55,8 +55,8 @@ import org.apache.spark.graphx.impl.VertexRDDImpl * @tparam VD the vertex attribute associated with each vertex in the set. 
*/ abstract class VertexRDD[VD]( - @transient sc: SparkContext, - @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { + sc: SparkContext, + deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { implicit protected def vdTag: ClassTag[VD] diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..f632dd603c449 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.ml.source.libsvm.DefaultSource diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 6f70b96b17ec6..0a75d5d22280f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasCheckpointInterval import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21fbe38ca8233..a460262b87e43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) + val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 1e5b0bc4453e4..5f60dea91fcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} @@ -32,6 +34,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams with HasSeed with HasMaxIter with HasTol { /** * Layer sizes including input size and output size. 
+ * Default: Array(1, 1) * @group param */ final val layers: IntArrayParam = new IntArrayParam(this, "layers", @@ -50,6 +53,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams * Data is stacked within partitions. If block size is more than remaining data in * a partition then it is adjusted to the size of this data. * Recommended size is between 10 and 1000. + * Default: 128 * @group expertParam */ final val blockSize: IntParam = new IntParam(this, "blockSize", @@ -179,6 +183,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 56419a0a15952..08df2919a8a87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -38,6 +38,7 @@ class BinaryClassificationEvaluator(override val uid: String) /** * param for metric name in evaluation + * Default: areaUnderROC * @group param */ val metricName: Param[String] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 46314854d5e3a..edad754436455 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -41,6 +41,7 @@ final class Binarizer(override val uid: String) * Param for threshold used to binarize continuous features. * The features greater than the threshold, will be binarized to 1.0. * The features equal to or less than the threshold, will be binarized to 0.0. + * Default: 0.0 * @group param */ val threshold: DoubleParam = diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 938447447a0a2..4c36df75d8aa0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -35,6 +35,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** * The minimum of documents in which a term should appear. + * Default: 0 * @group param */ final val minDocFreq = new IntParam( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5d77ea08db657..2a79582625e9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp /** * stop words list */ -private object StopWords { +private[spark] object StopWords { /** * Use the same default stopwords list as scikit-learn. 
* The original list can be found from "Glasgow Information Retrieval Group" * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] */ - val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + val English = Array( "a", "about", "above", "across", "after", "afterwards", "again", "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", @@ -98,6 +98,7 @@ class StopWordsRemover(override val uid: String) /** * the stop words set to be filtered out + * Default: [[StopWords.English]] * @group param */ val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") @@ -110,6 +111,7 @@ class StopWordsRemover(override val uid: String) /** * whether to do a case sensitive comparison over the stop words + * Default: false * @group param */ val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", @@ -121,7 +123,7 @@ class StopWordsRemover(override val uid: String) /** @group getParam */ def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + setDefault(stopWords -> StopWords.English, caseSensitive -> false) override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 24250e4c4cf92..3a4ab9a857648 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,7 +102,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. * - * @param labels Ordered list of labels, corresponding to indices to be assigned + * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Experimental class StringIndexerModel ( @@ -181,10 +181,10 @@ class StringIndexerModel ( /** * :: Experimental :: - * A [[Transformer]] that maps a column of string indices back to a new column of corresponding - * string values using either the ML attributes of the input column, or if provided using the labels - * supplied by the user. - * All original columns are kept during transformation. + * A [[Transformer]] that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping is either from the ML attributes of the input column, + * or from user-supplied labels (which take precedence over ML attributes). * * @see [[StringIndexer]] for converting strings into indices */ @@ -202,32 +202,23 @@ class IndexToString private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. The default value is an empty array, - * but the empty array is ignored and column metadata used instead. - * @group setParam - */ + /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) /** - * Param for array of labels. 
- * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. + * Optional param for array of labels specifying index-string mapping. + * + * Default: Empty array, in which case [[inputCol]] metadata is used for labels. * @group param */ final val labels: StringArrayParam = new StringArrayParam(this, "labels", - "array of labels, if not provided metadata from inputCol is used instead.") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") setDefault(labels, Array.empty[String]) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. - * @group getParam - */ + /** @group getParam */ final def getLabels: Array[String] = $(labels) - /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 61b925c0fdc07..52e0599e38d83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -43,6 +43,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * Must be >= 2. * * (default = 20) + * @group param */ val maxCategories = new IntParam(this, "maxCategories", "Threshold for the number of values a categorical feature can take (>= 2)." + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index c5c2272270792..fb3387d4aa9be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -49,6 +49,7 @@ final class VectorSlicer(override val uid: String) /** * An array of indices to select features from a vector column. * There can be no overlap with [[names]]. + * Default: Empty array * @group param */ val indices = new IntArrayParam(this, "indices", @@ -67,6 +68,7 @@ final class VectorSlicer(override val uid: String) * An array of feature names to select features from a vector column. * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s. * There can be no overlap with [[indices]]. + * Default: Empty Array * @group param */ val names = new StringArrayParam(this, "names", diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 5af775a4159ad..9edab3af913ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -39,6 +39,7 @@ private[feature] trait Word2VecBase extends Params /** * The dimension of the code that you want to transform from words. + * Default: 100 * @group param */ final val vectorSize = new IntParam( @@ -50,6 +51,7 @@ private[feature] trait Word2VecBase extends Params /** * Number of partitions for sentences of words. + * Default: 1 * @group param */ final val numPartitions = new IntParam( @@ -62,6 +64,7 @@ private[feature] trait Word2VecBase extends Params /** * The minimum number of times a token must appear to be included in the word2vec model's * vocabulary. 
+ * Default: 5 * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala new file mode 100644 index 0000000000000..a99e2ac4c6913 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import com.github.fommil.netlib.LAPACK.{getInstance => lapack} +import org.netlib.util.intW + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[WeightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + */ +private[ml] class WeightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double) extends Serializable + +/** + * Weighted least squares solver via normal equation. + * Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares + * formulation: + * + * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i + * + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^, + * + * where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by + * [[standardizeLabel]] and [[standardizeFeatures]], respectively. + * + * Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to + * match R's `lm`. + * Turn on [[standardizeLabel]] to match R's `glmnet`. + * + * @param fitIntercept whether to fit intercept. If false, z is 0.0. + * @param regParam L2 regularization parameter (lambda) + * @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the + * population standard deviation of the j-th column of A. Otherwise, + * sigma,,j,, is 1.0. + * @param standardizeLabel whether to standardize label. If true, delta is the population standard + * deviation of the label column b. Otherwise, delta is 1.0. + */ +private[ml] class WeightedLeastSquares( + val fitIntercept: Boolean, + val regParam: Double, + val standardizeFeatures: Boolean, + val standardizeLabel: Boolean) extends Logging with Serializable { + import WeightedLeastSquares._ + + require(regParam >= 0.0, s"regParam cannot be negative: $regParam") + if (regParam == 0.0) { + logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + + /** + * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. 
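The objective documented in the scaladoc above is written in wiki-style markup; as a reading aid (my transcription, not part of the patch), in standard notation it is

$$
\min_{x,\,z}\;\; \frac{1}{2}\,\frac{\sum_i w_i \left(a_i^\top x + z - b_i\right)^2}{\sum_i w_i}
\;+\; \frac{1}{2}\,\frac{\lambda}{\delta}\,\sum_j \left(\sigma_j x_j\right)^2
$$

where lambda is regParam, and delta and sigma_j are 1.0 unless standardizeLabel / standardizeFeatures are turned on, in which case they are the population standard deviations of the label and of feature j. The fit() method that follows minimizes this by assembling the regularized normal equations from aggregated weighted moments and solving them with a Cholesky factorization.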
+ */ + def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) + summary.validate() + logInfo(s"Number of instances: ${summary.count}.") + val triK = summary.triK + val bBar = summary.bBar + val bStd = summary.bStd + val aBar = summary.aBar + val aVar = summary.aVar + val abBar = summary.abBar + val aaBar = summary.aaBar + val aaValues = aaBar.values + + if (fitIntercept) { + // shift centers + // A^T A - aBar aBar^T + RowMatrix.dspr(-1.0, aBar, aaValues) + // A^T b - bBar aBar + BLAS.axpy(-bBar, aBar, abBar) + } + + // add regularization to diagonals + var i = 0 + var j = 2 + while (i < triK) { + var lambda = regParam + if (standardizeFeatures) { + lambda *= aVar(j - 2) + } + if (standardizeLabel) { + // TODO: handle the case when bStd = 0 + lambda /= bStd + } + aaValues(i) += lambda + i += j + j += 1 + } + + val x = choleskySolve(aaBar.values, abBar) + + // compute intercept + val intercept = if (fitIntercept) { + bBar - BLAS.dot(aBar, x) + } else { + 0.0 + } + + new WeightedLeastSquaresModel(x, intercept) + } + + /** + * Solves a symmetric positive definite linear system via Cholesky factorization. + * The input arguments are modified in-place to store the factorization and the solution. + * @param A the upper triangular part of A + * @param bx right-hand side + * @return the solution vector + */ + // TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS + private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = { + val k = bx.size + val info = new intW(0) + lapack.dppsv("U", k, 1, A, bx.values, k, info) + val code = info.`val` + assert(code == 0, s"lapack.dppsv returned $code.") + bx + } +} + +private[ml] object WeightedLeastSquares { + + /** + * Case class for weighted observations. + * @param w weight, must be >= 0.0 + * @param a features + * @param b label + */ + case class Instance(w: Double, a: Vector, b: Double) { + require(w >= 0.0, s"Weight cannot be negative: $w.") + } + + /** + * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. + */ + // TODO: consolidate aggregates for summary statistics + private class Aggregator extends Serializable { + var initialized: Boolean = false + var k: Int = _ + var count: Long = _ + var triK: Int = _ + private var wSum: Double = _ + private var wwSum: Double = _ + private var bSum: Double = _ + private var bbSum: Double = _ + private var aSum: DenseVector = _ + private var abSum: DenseVector = _ + private var aaSum: DenseVector = _ + + private def init(k: Int): Unit = { + require(k <= 4096, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to 4096 but got $k.") + this.k = k + triK = k * (k + 1) / 2 + count = 0L + wSum = 0.0 + wwSum = 0.0 + bSum = 0.0 + bbSum = 0.0 + aSum = new DenseVector(Array.ofDim(k)) + abSum = new DenseVector(Array.ofDim(k)) + aaSum = new DenseVector(Array.ofDim(triK)) + initialized = true + } + + /** + * Adds an instance. + */ + def add(instance: Instance): this.type = { + val Instance(w, a, b) = instance + val ak = a.size + if (!initialized) { + init(ak) + initialized = true + } + assert(ak == k, s"Dimension mismatch. Expected vectors of size $k but got $ak.") + count += 1L + wSum += w + wwSum += w * w + bSum += w * b + bbSum += w * b * b + BLAS.axpy(w, a, aSum) + BLAS.axpy(w * b, a, abSum) + RowMatrix.dspr(w, a, aaSum.values) + this + } + + /** + * Merges another [[Aggregator]].
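As a rough illustration of what choleskySolve does: it hands LAPACK's dppsv the upper triangle of the Gram matrix in column-major packed storage and lets the routine overwrite the right-hand side with the solution. A minimal, self-contained sketch of that call on a made-up 2x2 system (the object name and values below are hypothetical and not from this patch; the dppsv signature is the same netlib-java call the new file already uses):

{{{
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

object PackedCholeskySketch {
  def main(args: Array[String]): Unit = {
    // Upper triangle of the SPD matrix [[4, 1], [1, 3]], packed column-major: (a11, a12, a22).
    val ap = Array(4.0, 1.0, 3.0)
    // Right-hand side of A x = b; dppsv overwrites it with the solution x.
    val b = Array(1.0, 2.0)
    val info = new intW(0)
    lapack.dppsv("U", 2, 1, ap, b, 2, info)
    require(info.`val` == 0, s"dppsv returned ${info.`val`}")
    println(b.mkString(", ")) // roughly 0.0909, 0.6364
  }
}
}}}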
+ */ + def merge(other: Aggregator): this.type = { + if (!other.initialized) { + this + } else { + if (!initialized) { + init(other.k) + } + assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}") + count += other.count + wSum += other.wSum + wwSum += other.wwSum + bSum += other.bSum + bbSum += other.bbSum + BLAS.axpy(1.0, other.aSum, aSum) + BLAS.axpy(1.0, other.abSum, abSum) + BLAS.axpy(1.0, other.aaSum, aaSum) + this + } + } + + /** + * Validates that we have seen observations. + */ + def validate(): Unit = { + assert(initialized, "Training dataset is empty.") + assert(wSum > 0.0, "Sum of weights cannot be zero.") + } + + /** + * Weighted mean of features. + */ + def aBar: DenseVector = { + val output = aSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of labels. + */ + def bBar: Double = bSum / wSum + + /** + * Weighted population standard deviation of labels. + */ + def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar) + + /** + * Weighted mean of (label * features). + */ + def abBar: DenseVector = { + val output = abSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted mean of (features * features^T^). + */ + def aaBar: DenseVector = { + val output = aaSum.copy + BLAS.scal(1.0 / wSum, output) + output + } + + /** + * Weighted population variance of features. + */ + def aVar: DenseVector = { + val variance = Array.ofDim[Double](k) + var i = 0 + var j = 2 + val aaValues = aaSum.values + while (i < triK) { + val l = j - 2 + val aw = aSum(l) / wSum + variance(l) = aaValues(i) / wSum - aw * aw + i += j + j += 1 + } + new DenseVector(variance) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 91c0a5631319d..de32b7218c277 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable { */ final def getOrDefault[T](param: Param[T]): T = { shouldOwn(param) - get(param).orElse(getDefault(param)).get + get(param).orElse(getDefault(param)).getOrElse( + throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")) } /** An alias for [[getOrDefault()]]. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8c16c6149b40d..e9e99ed1db40e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -56,7 +56,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", + ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + + "the cache will get checkpointed every 10 iterations.", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. 
Options are skip (which " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index c26768953e3db..30092170863ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params { private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). + * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 0f33bae30e622..2ff500f291abc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -40,6 +40,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures /** * Param for whether the output sequence should be isotonic/increasing (true) or * antitonic/decreasing (false). + * Default: true * @group param */ final val isotonic: BooleanParam = @@ -202,7 +203,7 @@ class IsotonicRegressionModel private[ml] ( def predictions: Vector = Vectors.dense(oldModel.predictions) override def copy(extra: ParamMap): IsotonicRegressionModel = { - copyValues(new IsotonicRegressionModel(uid, oldModel), extra) + copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } override def transform(dataset: DataFrame): DataFrame = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 884003eb38524..e4602d36ccc87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -310,7 +310,7 @@ class LinearRegressionModel private[ml] ( } override def copy(extra: ParamMap): LinearRegressionModel = { - val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala new file mode 100644 index 0000000000000..1f627777fc68d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm + +import com.google.common.base.Objects + +import org.apache.spark.Logging +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** + * LibSVMRelation provides the DataFrame constructed from LibSVM format data. + * @param path File path of LibSVM format + * @param numFeatures The number of features + * @param vectorType The type of vector. It can be 'sparse' or 'dense' + * @param sqlContext The Spark SQLContext + */ +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) + (@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan with Logging with Serializable { + + override def schema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil + ) + + override def buildScan(): RDD[Row] = { + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + val sparse = vectorType == "sparse" + baseRdd.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + Row(pt.label, features) + } + } + + override def hashCode(): Int = { + Objects.hashCode(path, Double.box(numFeatures), vectorType) + } + + override def equals(other: Any): Boolean = other match { + case that: LibSVMRelation => + path == that.path && + numFeatures == that.numFeatures && + vectorType == that.vectorType + case _ => + false + } +} + +/** + * The `libsvm` package implements the Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. + * + * To use the LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * DataFrame df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass.
+ * Specifying the number of features explicitly is also useful when the dataset is already split + * into multiple files and you want to load them separately, because some features may not be + * present in certain files, which leads to inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] + */ +@Since("1.6.0") +class DefaultSource extends RelationProvider with DataSourceRegister { + + @Since("1.6.0") + override def shortName(): String = "libsvm" + + @Since("1.6.0") + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) + : BaseRelation = { + val path = parameters.getOrElse("path", + throw new IllegalArgumentException("'path' must be specified")) + val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt + val vectorType = parameters.getOrElse("vectorType", "sparse") + new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index dbd8d31571d2e..d29f5253c9c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams { +private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { /** * Maximum depth of the tree (>= 0). @@ -96,21 +96,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams { " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * Specifies how often to checkpoint the cached node IDs. - * E.g. 10 means that the cache will get checkpointed every 10 iterations. - * This is only used if cacheNodeIds is true and if the checkpoint directory is set in - * [[org.apache.spark.SparkContext]]. - * Must be >= 1. - * (default = 10) - * @group expertParam - */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + - " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + - " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext.
Must be >= 1.", - ParamValidators.gtEq(1)) - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) @@ -150,12 +135,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams { /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group expertSetParam */ + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertSetParam + */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** @group expertGetParam */ - final def getCheckpointInterval: Int = $(checkpointInterval) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index fcb517b5f735e..96a38a3bde960 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField /** - * Helper utilities for tree-based algorithms + * Helper utilities for algorithms using ML metadata */ private[spark] object MetadataUtils { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 7f6163e04bf17..a5902190d4637 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -168,10 +168,9 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) val dataFrame = sqlContext.read.parquet(dataPath) - val dataArray = dataFrame.select("weight", "mu", "sigma").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() val (weights, gaussians) = dataArray.map { case Row(weight: Double, mu: Vector, sigma: Matrix) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 36b124c5d2966..58857c338f546 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -590,12 +590,10 @@ object Word2VecModel extends Loader[Word2VecModel] { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) val dataFrame = sqlContext.read.parquet(dataPath) - - val dataArray = dataFrame.select("word", "vector").collect() - // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) + val dataArray = dataFrame.select("word", "vector").collect() val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap new Word2VecModel(word2VecMap) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index ab475af264dd3..9ee81eda8a8c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging { } } + /** Y += a * x */ + private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = { + require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " + + s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.") + f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1) + } + /** * dot(x, y) */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a423ddafdc09..83779ac88989b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -678,7 +678,8 @@ object RowMatrix { * * @param U the upper triangular part of the matrix packed in an array (column major) */ - private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + // TODO: SPARK-10491 - move this method to linalg.BLAS + private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { // TODO: Find a better home (breeze?) for this method. val n = v.size v match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index a2d85a68cd327..9eab7efc160da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import org.apache.commons.math3.distribution.{ExponentialDistribution, - GammaDistribution, LogNormalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution._ import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -195,3 +194,27 @@ class LogNormalGenerator @Since("1.3.0") ( @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } + +/** + * :: DeveloperApi :: + * Generates i.i.d. samples from the Weibull distribution with the + * given shape and scale parameter. + * + * @param alpha shape parameter for the Weibull distribution. + * @param beta scale parameter for the Weibull distribution. 
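The WeibullGenerator declared just below wraps commons-math3's WeibullDistribution, so sampling and seeding follow the existing RandomDataGenerator contract. A hypothetical usage sketch (not from this patch); the expected-mean formula is the same one the RandomDataGeneratorSuite test later in this diff uses:

{{{
import org.apache.commons.math3.special.Gamma
import org.apache.spark.mllib.random.WeibullGenerator

val gen = new WeibullGenerator(2.0, 3.0) // alpha (shape) = 2.0, beta (scale) = 3.0
gen.setSeed(42L)
val samples = Array.fill(10000)(gen.nextValue())
val sampleMean = samples.sum / samples.length
// E[X] = beta * Gamma(1 + 1 / alpha); the sample mean should land close to this.
val expectedMean = math.exp(Gamma.logGamma(1.0 + 1.0 / 2.0)) * 3.0
println(s"sample mean = $sampleMean, expected mean = $expectedMean")
}}}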
+ */ +@DeveloperApi +class WeibullGenerator( + val alpha: Double, + val beta: Double) extends RandomDataGenerator[Double] { + + private val rng = new WeibullDistribution(alpha, beta) + + override def nextValue(): Double = rng.sample() + + override def setSeed(seed: Long): Unit = { + rng.reseedRandomGenerator(seed) + } + + override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index 910eff9540a47..f8cea7ecea6bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -35,11 +35,11 @@ private[mllib] class RandomRDDPartition[T](override val index: Int, } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: RandomDataGenerator[T], - @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { + @transient private val rng: RandomDataGenerator[T], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") @@ -56,12 +56,12 @@ private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, } } -private[mllib] class RandomVectorRDD(@transient sc: SparkContext, +private[mllib] class RandomVectorRDD(sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: RandomDataGenerator[Double], - @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { + @transient private val rng: RandomDataGenerator[Double], + @transient private val seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 618b95b9bd126..fd22eb6dca018 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -22,6 +22,7 @@ import java.util.List; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -63,16 +64,16 @@ public void tearDown() { @Test public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); - assert(lr.getLabelCol().equals("label")); + Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults - assert(model.getThreshold() == 0.5); - assert(model.getFeaturesCol().equals("features")); - assert(model.getPredictionCol().equals("prediction")); - assert(model.getProbabilityCol().equals("probability")); + Assert.assertEquals(0.5, model.getThreshold(), eps); + Assert.assertEquals("features", model.getFeaturesCol()); + Assert.assertEquals("prediction", model.getPredictionCol()); 
+ Assert.assertEquals("probability", model.getProbabilityCol()); } @Test @@ -85,19 +86,19 @@ public void logisticRegressionWithSetters() { .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); LogisticRegression parent = (LogisticRegression) model.parent(); - assert(parent.getMaxIter() == 10); - assert(parent.getRegParam() == 1.0); - assert(parent.getThresholds()[0] == 0.4); - assert(parent.getThresholds()[1] == 0.6); - assert(parent.getThreshold() == 0.6); - assert(model.getThreshold() == 0.6); + Assert.assertEquals(10, parent.getMaxIter()); + Assert.assertEquals(1.0, parent.getRegParam(), eps); + Assert.assertEquals(0.4, parent.getThresholds()[0], eps); + Assert.assertEquals(0.6, parent.getThresholds()[1], eps); + Assert.assertEquals(0.6, parent.getThreshold(), eps); + Assert.assertEquals(0.6, model.getThreshold(), eps); // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { - assert(r.getDouble(0) == 0.0); + Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) @@ -107,17 +108,17 @@ public void logisticRegressionWithSetters() { for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } - assert(foundNonZero); + Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); - assert(parent2.getMaxIter() == 5); - assert(parent2.getRegParam() == 0.1); - assert(parent2.getThreshold() == 0.4); - assert(model2.getThreshold() == 0.4); - assert(model2.getProbabilityCol().equals("theProb")); + Assert.assertEquals(5, parent2.getMaxIter()); + Assert.assertEquals(0.1, parent2.getRegParam(), eps); + Assert.assertEquals(0.4, parent2.getThreshold(), eps); + Assert.assertEquals(0.4, model2.getThreshold(), eps); + Assert.assertEquals("theProb", model2.getProbabilityCol()); } @SuppressWarnings("unchecked") @@ -125,18 +126,18 @@ public void logisticRegressionWithSetters() { public void logisticRegressionPredictorClassifierMethods() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); - assert(model.numClasses() == 2); + Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); for (Row row: trans1.collect()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); - assert(raw.size() == 2); - assert(prob.size() == 2); + Assert.assertEquals(raw.size(), 2); + Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); - assert(Math.abs(prob.apply(1) - probFromRaw1) < eps); - assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps); + Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps); + Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); @@ -145,7 +146,7 @@ public void 
logisticRegressionPredictorClassifierMethods() { Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); for (int i = 0; i < prob.size(); ++i) { - assert(probOfPred >= prob.apply(i)); + Assert.assertTrue(probOfPred >= prob.apply(i)); } } } @@ -156,6 +157,6 @@ public void logisticRegressionTrainingSummary() { LogisticRegressionModel model = lr.fit(dataset); LogisticRegressionTrainingSummary summary = model.summary(); - assert(summary.totalIterations() == summary.objectiveHistory().length); + Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 8fd7bf55a2e5d..075a62c493f17 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -23,6 +23,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -58,18 +59,18 @@ public void validatePrediction(DataFrame predictionAndLabels) { for (Row r : predictionAndLabels.collect()) { double prediction = r.getAs(0); double label = r.getAs(1); - assert(prediction == label); + assertEquals(label, prediction, 1E-5); } } @Test public void naiveBayesDefaultParams() { NaiveBayes nb = new NaiveBayes(); - assert(nb.getLabelCol() == "label"); - assert(nb.getFeaturesCol() == "features"); - assert(nb.getPredictionCol() == "prediction"); - assert(nb.getSmoothing() == 1.0); - assert(nb.getModelType() == "multinomial"); + assertEquals("label", nb.getLabelCol()); + assertEquals("features", nb.getFeaturesCol()); + assertEquals("prediction", nb.getPredictionCol()); + assertEquals(1.0, nb.getSmoothing(), 1E-5); + assertEquals("multinomial", nb.getModelType()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index d591a456864e4..91c589d00abd5 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -60,7 +60,7 @@ public void tearDown() { @Test public void linearRegressionDefaultParams() { LinearRegression lr = new LinearRegression(); - assert(lr.getLabelCol().equals("label")); + assertEquals("label", lr.getLabelCol()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java new file mode 100644 index 0000000000000..2976b38e45031 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm; + +import java.io.File; +import java.io.IOException; + +import com.google.common.base.Charsets; +import com.google.common.io.Files; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.DenseVector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + + +/** + * Test LibSVMRelation in Java. + */ +public class JavaLibSVMRelationSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + private File tempDir; + private String path; + + @Before + public void setUp() throws IOException { + jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); + sqlContext = new SQLContext(jsc); + + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + File file = new File(tempDir, "part-00000"); + String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; + Files.write(s, file, Charsets.US_ASCII); + path = tempDir.toURI().toString(); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + Utils.deleteRecursively(tempDir); + } + + @Test + public void verifyLibSVMDF() { + DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + .load(path); + Assert.assertEquals("label", dataset.columns()[0]); + Assert.assertEquals("features", dataset.columns()[1]); + Row r = dataset.first(); + Assert.assertEquals(1.0, r.getDouble(0), 1e-15); + DenseVector v = r.getAs(1); + Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 3349c5022423a..8beea102efd01 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -80,10 +80,10 @@ public void diagonalMatrixConstruction() { assertArrayEquals(sd.toArray(), s.toArray(), 0.0); assertArrayEquals(s.toArray(), ss.toArray(), 0.0); assertArrayEquals(s.values(), ss.values(), 0.0); - assert(s.values().length == 2); - assert(ss.values().length == 2); - assert(s.colPtrs().length == 4); - assert(ss.colPtrs().length == 4); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); } @Test @@ -137,27 +137,27 @@ public void concatenateMatrices() { Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - assert(deHorz1.numRows() == 3); - assert(deHorz2.numRows() == 3); - assert(deHorz3.numRows() == 3); - 
assert(spHorz.numRows() == 3); - assert(deHorz1.numCols() == 5); - assert(deHorz2.numCols() == 5); - assert(deHorz3.numCols() == 5); - assert(spHorz.numCols() == 5); + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - assert(deVert1.numRows() == 5); - assert(deVert2.numRows() == 5); - assert(deVert3.numRows() == 5); - assert(spVert.numRows() == 5); - assert(deVert1.numCols() == 2); - assert(deVert2.numCols() == 2); - assert(deVert3.numCols() == 2); - assert(spVert.numCols() == 2); + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index f01306f89cb5f..e0d433f566c25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { } test("StopWordsRemover with additional words") { - val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val stopWords = StopWords.English ++ Array("python", "scala") val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala new file mode 100644 index 0000000000000..652f3adb984d3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.optim.WeightedLeastSquares.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + instances = sc.parallelize(Seq( + Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0), + Instance(2.0, Vectors.dense(1.0, 7.0), 19.0), + Instance(3.0, Vectors.dense(2.0, 11.0), 23.0), + Instance(4.0, Vectors.dense(3.0, 13.0), 29.0) + ), 2) + } + + test("WLS against lm") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- lm(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -3.727121 3.009983 + [1] 18.08 6.08 -0.60 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727121, 3.009983), + Vectors.dense(18.08, 6.08, -0.60)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("WLS against glmnet") { + /* + R code: + + library(glmnet) + + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + for (standardize in c(FALSE, TRUE)) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=lambda, + standardize=standardize, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + } + + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.727117 3.009982 + [1] 0.000000 -3.307532 2.924206 + [1] 0.000000 -2.914790 2.840627 + [1] 0.000000 -1.526575 2.558158 + [1] 0.00000000 0.06984238 2.20488344 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 18.0799727 6.0799832 -0.5999941 + [1] 13.5356178 3.2714044 0.3770744 + [1] 14.064629 3.565802 0.269593 + [1] 10.1238013 0.9708569 1.1475466 + [1] 13.1860638 2.1761382 0.6213134 + */ + + val expected = Seq( + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.727117, 3.009982), + Vectors.dense(0.0, -3.307532, 2.924206), + Vectors.dense(0.0, -2.914790, 2.840627), + Vectors.dense(0.0, -1.526575, 2.558158), + Vectors.dense(0.0, 0.06984238, 2.20488344), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(18.0799727, 6.0799832, -0.5999941), + Vectors.dense(13.5356178, 3.2714044, 0.3770744), + Vectors.dense(14.064629, 3.565802, 0.269593), + Vectors.dense(10.1238013, 0.9708569, 1.1475466), + Vectors.dense(13.1860638, 2.1761382, 0.6213134)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0); + standardizeFeatures <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam, standardizeFeatures, standardizeLabel = true) + .fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 
2c878f8372a47..dfab82c8b67ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite { assert(inputCol.toString === s"${uid}__inputCol") + intercept[java.util.NoSuchElementException] { + solver.getOrDefault(solver.handleInvalid) + } + intercept[IllegalArgumentException] { solver.setMaxIter(-1) } @@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{handleInvalid, maxIter, inputCol} val params = solver.params - assert(params.length === 2) - assert(params(0).eq(inputCol), "params must be ordered by name") - assert(params(1).eq(maxIter)) + assert(params.length === 3) + assert(params(0).eq(handleInvalid), "params must be ordered by name") + assert(params(1).eq(inputCol), "params must be ordered by name") + assert(params(2).eq(maxIter)) assert(!solver.isSet(maxIter)) assert(solver.isDefined(maxIter)) @@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite { assert(solver.explainParam(maxIter) === "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === - Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) + Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 2759248344531..9d23547f28447 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.param -import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter} import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. */ -class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter + with HasInputCol { def this() = this(Identifiable.randomUID("testParams")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index c0ab00b68a2f3..59f4193abc8f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -89,6 +90,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ir.getFeatureIndex === 0) val model = ir.fit(dataset) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "features", "prediction", "weight") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala new file mode 100644 index 0000000000000..997f574e51f6a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var tempDir: File = _ + var path: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val lines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 + """.stripMargin + tempDir = Utils.createTempDir() + val file = new File(tempDir, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + path = tempDir.toURI.toString + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + + test("select as sparse vector") { + val df = sqlContext.read.format("libsvm").load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("select as dense vector") { + val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + assert(df.count() == 3) + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[DenseVector](1) + assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + } + + test("select a vector with specifying the longer dimension") { + val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + .load(path) + val row1 = df.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index a5ca1518f82f5..8416771552fd3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.random -import scala.math +import org.apache.commons.math3.special.Gamma import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite { distributionChecks(gamma, expectedMean, expectedStd, 0.1) } } + + test("WeibullGenerator") { + List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { + case (alpha: Double, beta: Double) => + val weibull = new WeibullGenerator(alpha, beta) + apiChecks(weibull) + + val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta + val expectedVariance = math.exp( + Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean + val expectedStd = math.sqrt(expectedVariance) + distributionChecks(weibull, expectedMean, expectedStd, 0.1) + } + } } diff --git a/network/common/pom.xml b/network/common/pom.xml index 7dc3068ab8cb7..4141fcb8267a5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,6 +48,10 @@ slf4j-api provided + + com.google.code.findbugs + jsr305 + graphx mllib tools diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 88745dc086a04..3b8b6c8ffa375 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -32,634 +32,638 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") */ object MimaExcludes { - def excludes(version: String) = - version match { - case v if v.startsWith("1.5") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // JavaRDDLike is not meant to be extended by user programs - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.partitioner"), - // Modification of private static method - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), - // Mima false positive (was a private[spark] class) - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator"), - // Removing a testing method from a private class - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), - // While private MiMa is still not happy about the changes, - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - // SQL execution is considered private. 
- excludePackage("org.apache.spark.sql.execution"), - // The old JSON RDD is removed in favor of streaming Jackson - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), - // local function inside a method - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") - ) ++ Seq( - // SPARK-8479 Add numNonzeros and numActives to Matrix. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numActives") - ) ++ Seq( - // SPARK-8914 Remove RDDApi - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") - ) ++ Seq( - // SPARK-7292 Provide operator to truncate lineage cheaply - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.RDDCheckpointData"), - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.CheckpointRDD") - ) ++ Seq( - // SPARK-8701 Add input metadata in the batch page. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo") - ) ++ Seq( - // SPARK-6797 Support YARN modes for SparkR - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.PairwiseRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.createRWorker"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.StringRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.BaseRRDD.this") - ) ++ Seq( - // SPARK-7422 add argmax for sparse vectors - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.argmax") - ) ++ Seq( - // SPARK-8906 Move all internal data source classes into execution.datasources - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), - // SPARK-9763 Minimize exposure of internal SQL classes - excludePackage("org.apache.spark.sql.parquet"), - excludePackage("org.apache.spark.sql.json"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") - ) ++ Seq( - // SPARK-4751 Dynamic allocation for standalone mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.supportDynamicAllocation") - ) ++ Seq( - // SPARK-9580: Remove SQL test singletons - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext$") - ) ++ Seq( - // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.mllib.linalg.VectorUDT.serialize") - ) + def excludes(version: String) = version match { + case v if v.startsWith("1.6") => + Seq( + MimaBuild.excludeSparkPackage("network") + ) + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. 
+ excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // The old JSON RDD is removed in favor of streaming Jackson + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), + // local function inside a method + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") + ) ++ Seq( + // SPARK-8914 Remove RDDApi + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") + ) ++ Seq( + // SPARK-7292 Provide operator to truncate lineage cheaply + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.RDDCheckpointData"), + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.rdd.CheckpointRDD") + ) ++ Seq( + // SPARK-8701 Add input metadata in the batch page. 
+ ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.streaming.scheduler.InputInfo") + ) ++ Seq( + // SPARK-6797 Support YARN modes for SparkR + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.PairwiseRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.createRWorker"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.RRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.StringRRDD.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.r.BaseRRDD.this") + ) ++ Seq( + // SPARK-7422 add argmax for sparse vectors + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.argmax") + ) ++ Seq( + // SPARK-8906 Move all internal data source classes into execution.datasources + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") + ) ++ Seq( + // SPARK-4751 Dynamic allocation for standalone mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) - case v if v.startsWith("1.4") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partioner to JavaRDD, - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.rdd.JdbcRDD.compute"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") - ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though - // the stage class is defined as private[spark] - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") - ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") - ) ++ Seq( - // SPARK-6492 Fix deadlock in SparkContext.stop() - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + - "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") - )++ Seq( - // SPARK-6693 add tostring with max lines and width for matrix - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.toString") - )++ Seq( - // SPARK-6703 Add getOrCreate method to SparkContext - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") - )++ Seq( - // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.mllib.clustering.LDA$EMOptimizer") - ) ++ Seq( - // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector - 
ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.compressed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toDense"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toSparse"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives"), - // SPARK-7681 add SparseVector support for gemv - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.multiply") - ) ++ Seq( - // Execution should never be included as its always internal. - MimaBuild.excludeSparkPackage("sql.execution"), - // This `protected[sql]` method was removed in 1.3.1 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.checkAnalysis"), - // These `private[sql]` class were removed in 1.4.0: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), - // These test support classes were moved out of src/main and into src/test: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), - // TODO: Remove the following rule once ParquetTest has been moved to src/test. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTest") - ) ++ Seq( - // SPARK-7530 Added StreamingContext.getState() - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.StreamingContext.state_=") - ) ++ Seq( - // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some - // unnecessary type bounds in order to fix some compiler warnings that occurred when - // implementing this interface in Java. 
Note that ShuffleWriter is private[spark]. - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.shuffle.ShuffleWriter") - ) ++ Seq( - // SPARK-6888 make jdbc driver handling user definable - // This patch renames some classes to API friendly names. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") - ) + case v if v.startsWith("1.4") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") + ) ++ Seq( + // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // the stage class is defined as private[spark] + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") + ) ++ Seq( + // SPARK-6492 Fix deadlock in SparkContext.stop() + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + + "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") + )++ Seq( + // SPARK-6693 add tostring with max lines and width for matrix + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + 
ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") + ) ++ Seq( + // Execution should never be included as its always internal. + MimaBuild.excludeSparkPackage("sql.execution"), + // This `protected[sql]` method was removed in 1.3.1 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.checkAnalysis"), + // These `private[sql]` class were removed in 1.4.0: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.PartitionSpec$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.Partition$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), + // These test support classes were moved out of src/main and into src/test: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") + ) ++ Seq( + // SPARK-7530 Added StreamingContext.getState() + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. 
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") + ) - case v if v.startsWith("1.3") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in the 1.2 build. - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") - ) ++ Seq( - // SPARK-2321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfoImpl.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfo.submissionTime") - ) ++ Seq( - // SPARK-4614 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.randn"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.rand") - ) ++ Seq( - // SPARK-5321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.transpose"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + - "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.isTransposed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.foreachActive") - ) ++ Seq( - // SPARK-5540 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), - // SPARK-5536 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") - ) ++ Seq( - // SPARK-3325 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), - // SPARK-2757 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ - "removeAndGetProcessor") - ) ++ Seq( - // SPARK-5123 (SparkSQL data type change) - alpha component only - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.HashingTF.outputDataType"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.outputDataType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.validateInputType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") - ) ++ Seq( - // SPARK-4014 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.taskAttemptId"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.attemptNumber") - ) ++ Seq( - // SPARK-5166 Spark SQL API stabilization - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") - ) ++ Seq( - // SPARK-5270 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.isEmpty") - ) ++ Seq( - // SPARK-5430 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeReduce"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeAggregate") - ) ++ Seq( - // SPARK-5297 Java FileStream do not work with custom key/values - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") - ) ++ Seq( - // SPARK-5315 Spark Streaming Java API returns Scala DStream - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") - ) ++ Seq( - // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.getCheckpointFiles"), - 
ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.isCheckpointed") - ) ++ Seq( - // SPARK-4789 Standardize ML Prediction APIs - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") - ) ++ Seq( - // SPARK-5814 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") - ) ++ Seq( - // SPARK-4682 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") - ) ++ Seq( - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") - ) + case v if v.startsWith("1.3") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("ml"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in the 1.2 build. 
+ MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) ++ Seq( + // SPARK-2321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfoImpl.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") + ) ++ Seq( + // SPARK-5540 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), + // SPARK-5536 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") + ) ++ Seq( + // SPARK-3325 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), + // SPARK-2757 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ + "removeAndGetProcessor") + ) ++ Seq( + // SPARK-5123 (SparkSQL data type change) - alpha component only + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.HashingTF.outputDataType"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.outputDataType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.validateInputType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") + ) ++ Seq( + // SPARK-4014 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.taskAttemptId"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.attemptNumber") + ) ++ Seq( + // SPARK-5166 Spark SQL API stabilization + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") + ) ++ Seq( + // SPARK-5270 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") + ) ++ Seq( + // SPARK-5297 Java FileStream do not work with custom key/values + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") + ) ++ Seq( + // SPARK-5315 Spark Streaming Java API returns Scala DStream + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") + ) ++ Seq( + // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.getCheckpointFiles"), + 
ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.graphx.Graph.isCheckpointed") + ) ++ Seq( + // SPARK-4789 Standardize ML Prediction APIs + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") + ) ++ Seq( + // SPARK-5814 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") + ) ++ Seq( + // SPARK-4682 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") + ) ++ Seq( + // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") + ) - case v if v.startsWith("1.2") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ - Seq( - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.TaskLocation"), - // Added normL1 and normL2 to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), - 
// MapStatus should be private[spark] - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.PathResolver"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.client.BlockClientListener"), + case v if v.startsWith("1.2") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.TaskLocation"), + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), + // MapStatus should be private[spark] + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), - // TaskContext was promoted to Abstract class - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.util.collection.SortDataFormat") - ) ++ Seq( - // Adding new methods to the JavaRDDLike trait: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.takeAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.collectAsync") - ) ++ Seq( - // SPARK-3822 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-1209 - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.rdd.PairRDDFunctions") - ) ++ Seq( - // SPARK-4062 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") - ) + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") + ) ++ Seq( + // SPARK-3822 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") + ) - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // Should probably mark this as Experimental - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.getValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry") - ) ++ - Seq( - // Serializer interface change. See SPARK-3045. 
- ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.DeserializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.Serializer"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializerInstance") - )++ - Seq( - // Renamed putValues -> putArray + putIterator - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.TachyonStore.putValues") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.flume.FlumeReceiver.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - // Class was missing "@DeveloperApi" annotation in 1.0. 
- MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) ++ - Seq( // Package-private classes removed in SPARK-2341 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ - Seq( // package-private classes removed in MLlib - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") - ) ++ - Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") - ) ++ - Seq( // synthetic methods generated in LabeledPoint - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") - ) ++ - Seq ( // Scala 2.11 compatibility fix - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } -} + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. 
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // Should probably mark this as Experimental + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.getValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ + Seq( + // Renamed putValues -> putArray + putIterator + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.DiskStore.putValues"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.TachyonStore.putValues") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.flume.FlumeReceiver.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. 
+ "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + // Class was missing "@DeveloperApi" annotation in 1.0. + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) ++ + Seq( // Package-private classes removed in SPARK-2341 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + 
MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } +} \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt index 51820460ca1a0..c06687d8f197b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,5 +1,3 @@ -scalaVersion := "2.10.4" - resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" diff --git a/python/docs/index.rst b/python/docs/index.rst index f7eede9c3c82a..306ffdb0e0f13 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -29,6 +29,14 @@ Core classes: A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + :class:`pyspark.streaming.StreamingContext` + + Main entry point for Spark Streaming functionality. + + :class:`pyspark.streaming.DStream` + + A Discretized Stream (DStream), the basic abstraction in Spark Streaming. + :class:`pyspark.sql.SQLContext` Main entry point for DataFrame and SQL functionality. diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index 50822c93faba1..fc52a647543e7 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -15,3 +15,24 @@ pyspark.streaming.kafka module :members: :undoc-members: :show-inheritance: + +pyspark.streaming.kinesis module +-------------------------------- +.. automodule:: pyspark.streaming.kinesis + :members: + :undoc-members: + :show-inheritance: + +pyspark.streaming.flume.module +------------------------------ +.. automodule:: pyspark.streaming.flume + :members: + :undoc-members: + :show-inheritance: + +pyspark.streaming.mqtt module +----------------------------- +.. automodule:: pyspark.streaming.mqtt + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 5f70ac6ed8fe6..8475dfb1c6ad0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -48,6 +48,22 @@ from pyspark.status import * from pyspark.profiler import Profiler, BasicProfiler + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + + # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1b2a52ad64114..a0a1ccbeefb09 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -255,7 +255,7 @@ def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( "It appears that you are attempting to reference SparkContext from a broadcast " - "variable, action, or transforamtion. 
SparkContext can only be used on the driver, " + "variable, action, or transformation. SparkContext can only be used on the driver, " "not in code that it run on workers. For more information, see SPARK-5063." ) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83f808efc3bf0..88815e561f572 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,12 +26,14 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel'] + 'NaiveBayesModel', 'MultilayerPerceptronClassifier', + 'MultilayerPerceptronClassificationModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): """ Logistic regression. Currently, this class only supports binary classification. @@ -65,17 +67,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - thresholds = Param(Params._dummy(), "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") @@ -83,40 +74,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. 
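The LogisticRegression hunk here drops the locally declared elasticNetParam, fitIntercept, and thresholds params (their setters and getters are removed just below) in favor of the shared HasElasticNetParam, HasFitIntercept, HasStandardization, and HasThresholds mixins generated later in this patch. A minimal usage sketch of the consolidated API, assuming a running SparkContext as the module's own doctests do; keyword names are taken from the new __init__ signature above:

    from pyspark.ml.classification import LogisticRegression

    # Callers pass the same keywords as before; the params are now declared
    # once in pyspark.ml.param.shared instead of per estimator.
    lr = LogisticRegression(maxIter=10, regParam=0.1, elasticNetParam=0.5,
                            fitIntercept=True, standardization=True)
    print(lr.getElasticNetParam())  # 0.5, served by HasElasticNetParam
    print(lr.getFitIntercept())     # True, served by HasFitIntercept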
- self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") - #: param for thresholds or cutoffs in binary or multiclass classification - self.thresholds = \ - Param(self, "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -124,13 +98,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -142,32 +116,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -808,6 +756,135 @@ def theta(self): return self._call_java("theta") +@inherit_doc +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasSeed): + """ + Classifier trainer based on the Multilayer Perceptron. + Each layer has sigmoid activation function, output layer has softmax. 
+ Number of inputs has to be equal to the size of feature vectors. + Number of outputs has to be equal to the total number of labels. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (0.0, Vectors.dense([0.0, 0.0])), + ... (1.0, Vectors.dense([0.0, 1.0])), + ... (1.0, Vectors.dense([1.0, 0.0])), + ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> model = mlp.fit(df) + >>> model.layers + [2, 5, 2] + >>> model.weights.size + 27 + >>> testDF = sqlContext.createDataFrame([ + ... (Vectors.dense([1.0, 0.0]),), + ... (Vectors.dense([0.0, 0.0]),)], ["features"]) + >>> model.transform(testDF).show() + +---------+----------+ + | features|prediction| + +---------+----------+ + |[1.0,0.0]| 1.0| + |[0.0,0.0]| 0.0| + +---------+----------+ + ... + """ + + # a placeholder to make it appear in the generated doc + layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + + "neurons and output layer of 10 neurons, default is [1, 1].") + blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is more than " + + "remaining data in a partition then it is adjusted to the size of this " + + "data. Recommended size is between 10 and 1000, default is 128.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + """ + super(MultilayerPerceptronClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) + self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + + "100 neurons and output layer of 10 neurons, default is [1, 1].") + self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is " + + "more than remaining data in a partition then it is adjusted to " + + "the size of this data. Recommended size is between 10 and 1000, " + + "default is 128.") + self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + Sets params for MultilayerPerceptronClassifier. + """ + kwargs = self.setParams._input_kwargs + if layers is None: + return self._set(**kwargs).setLayers([1, 1]) + else: + return self._set(**kwargs) + + def _create_model(self, java_model): + return MultilayerPerceptronClassificationModel(java_model) + + def setLayers(self, value): + """ + Sets the value of :py:attr:`layers`. 
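The doctest above asserts model.weights.size == 27 for layers=[2, 5, 2]. That count is consistent with an affine (weights plus one bias per non-input neuron) parameterization, which is an assumption here rather than something the hunk states; a stand-alone check needs no Spark:

    # Each consecutive layer pair (n_in, n_out) contributes n_in * n_out weights
    # plus n_out bias terms: (2*5 + 5) + (5*2 + 2) = 27.
    layers = [2, 5, 2]
    n_weights = sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))
    print(n_weights)  # 27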
+ """ + self._paramMap[self.layers] = value + return self + + def getLayers(self): + """ + Gets the value of layers or its default value. + """ + return self.getOrDefault(self.layers) + + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + self._paramMap[self.blockSize] = value + return self + + def getBlockSize(self): + """ + Gets the value of blockSize or its default value. + """ + return self.getOrDefault(self.blockSize) + + +class MultilayerPerceptronClassificationModel(JavaModel): + """ + Model fitted by MultilayerPerceptronClassifier. + """ + + @property + def layers(self): + """ + array of layer sizes including input and output layers. + """ + return self._call_java("javaLayers") + + @property + def weights(self): + """ + vector of initial weights for the model that consists of the weights of layers. + """ + return self._call_java("weights") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0626281e200a1..92db8df80280b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,20 +22,23 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', + 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Binarize a column of continuous features given a threshold. >>> df = sqlContext.createDataFrame([(0.5,)], ["values"]) @@ -92,6 +95,8 @@ def getThreshold(self): @inherit_doc class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Maps a column of continuous features to a column of feature buckets. >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) @@ -169,6 +174,8 @@ def getSplits(self): @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. @@ -232,6 +239,8 @@ def getInverse(self): @inherit_doc class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided "weight" vector. 
In other words, it scales each column of the dataset by a scalar multiplier. @@ -289,6 +298,8 @@ def getScalingVec(self): @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ + .. note:: Experimental + Maps a sequence of terms to their term frequencies using the hashing trick. @@ -327,6 +338,8 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Compute the Inverse Document Frequency (IDF) given a collection of documents. >>> from pyspark.mllib.linalg import DenseVector @@ -387,14 +400,112 @@ def _create_model(self, java_model): class IDFModel(JavaModel): """ + .. note:: Experimental + Model fitted by IDF. """ +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to a common range [min, max] linearly using column summary + statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + feature E is calculated as, + + Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min + + For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) + + Note that since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> model = mmScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[0.0]| [0.0]| + |[2.0]| [1.0]| + +-----+------+ + ... + """ + + # a placeholder to make it appear in the generated doc + min = Param(Params._dummy(), "min", "Lower bound of the output feature range") + max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + + @keyword_only + def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + """ + super(MinMaxScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) + self.min = Param(self, "min", "Lower bound of the output feature range") + self.max = Param(self, "max", "Upper bound of the output feature range") + self._setDefault(min=0.0, max=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + Sets params for this MinMaxScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + self._paramMap[self.min] = value + return self + + def getMin(self): + """ + Gets the value of min or its default value. + """ + return self.getOrDefault(self.min) + + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + self._paramMap[self.max] = value + return self + + def getMax(self): + """ + Gets the value of max or its default value. + """ + return self.getOrDefault(self.max) + + def _create_model(self, java_model): + return MinMaxScalerModel(java_model) + + +class MinMaxScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MinMaxScaler`. 
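The MinMaxScaler docstring added above gives Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min, with the degenerate case E_max == E_min mapped to 0.5 * (max + min). A stand-alone arithmetic check of that formula against the doctest column [0.0, 2.0] (pure Python, no Spark required; the helper name rescale is just for illustration):

    def rescale(e, e_min, e_max, lo=0.0, hi=1.0):
        # Degenerate column: every value maps to the midpoint of [lo, hi].
        if e_max == e_min:
            return 0.5 * (hi + lo)
        return (e - e_min) / (e_max - e_min) * (hi - lo) + lo

    column = [0.0, 2.0]  # the doctest column "a"
    print([rescale(v, min(column), max(column)) for v in column])  # [0.0, 1.0]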
+ """ + + @inherit_doc @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of @@ -463,6 +574,8 @@ def getN(self): @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Normalize a vector to have unit norm using the given p-norm. >>> from pyspark.mllib.linalg import Vectors @@ -519,6 +632,8 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. @@ -591,6 +706,8 @@ def getDropLast(self): @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that @@ -649,6 +766,8 @@ def getDegree(self): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false). @@ -746,6 +865,8 @@ def getPattern(self): @inherit_doc class SQLTransformer(JavaTransformer): """ + .. note:: Experimental + Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. @@ -797,6 +918,8 @@ def getStatement(self): @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. @@ -870,6 +993,8 @@ def _create_model(self, java_model): class StandardScalerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StandardScaler. """ @@ -889,8 +1014,10 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): """ + .. note:: Experimental + A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. The indices are in [0, numLabels), ordered by label frequencies. @@ -902,22 +1029,28 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> itd = inverter.transform(td) + >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), + ... 
key=lambda x: x[0]) + [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] """ @keyword_only - def __init__(self, inputCol=None, outputCol=None): + def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCol=None, outputCol=None) + __init__(self, inputCol=None, outputCol=None, handleInvalid="error") """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) + self._setDefault(handleInvalid="error") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol=None, outputCol=None): + def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCol=None, outputCol=None) + setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. """ kwargs = self.setParams._input_kwargs @@ -929,14 +1062,147 @@ def _create_model(self, java_model): class StringIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by StringIndexer. """ + @property + def labels(self): + """ + Ordered list of labels, corresponding to indices to be assigned. + """ + return self._java_obj.labels + + +@inherit_doc +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A :py:class:`Transformer` that maps a column of indices back to a new column of + corresponding string values. + The index-string mapping is either from the ML attributes of the input column, + or from user-supplied labels (which take precedence over ML attributes). + See L{StringIndexer} for converting strings into indices. + """ + + # a placeholder to make the labels show up in generated doc + labels = Param(Params._dummy(), "labels", + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, labels=None): + """ + __init__(self, inputCol=None, outputCol=None, labels=None) + """ + super(IndexToString, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", + self.uid) + self.labels = Param(self, "labels", + "Optional array of labels specifying index-string mapping. If not" + + " provided or if empty, then metadata from inputCol is used instead.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, labels=None): + """ + setParams(self, inputCol=None, outputCol=None, labels=None) + Sets params for this IndexToString. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setLabels(self, value): + """ + Sets the value of :py:attr:`labels`. + """ + self._paramMap[self.labels] = value + return self + + def getLabels(self): + """ + Gets the value of :py:attr:`labels` or its default value. + """ + return self.getOrDefault(self.labels) + + +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A feature transformer that filters out stop words from input. + Note: null values from input array are preserved unless adding null to stopWords explicitly. 
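A hedged usage sketch for the StopWordsRemover transformer whose definition begins here (its params and constructor continue just below). It mirrors the test_stopwordsremover case added to python/pyspark/ml/tests.py later in this patch and assumes a running SparkContext plus a sqlContext handle, as the module doctests do:

    from pyspark.ml.feature import StopWordsRemover
    from pyspark.sql import Row

    df = sqlContext.createDataFrame([Row(input=["a", "panda"])])
    remover = StopWordsRemover(inputCol="input", outputCol="output")
    remover.transform(df).head().output   # ['panda']: default English stop words drop "a"
    remover.setStopWords(["panda"])       # custom stop-word list
    remover.transform(df).head().output   # ['a']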
+ """ + # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + + "comparison over the stop words") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + __init__(self, inputCol=None, outputCol=None, stopWords=None,\ + caseSensitive=false) + """ + super(StopWordsRemover, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", + self.uid) + self.stopWords = Param(self, "stopWords", "The words to be filtered out") + self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + + "sensitive comparison over the stop words") + stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords + defaultStopWords = stopWordsObj.English() + self._setDefault(stopWords=defaultStopWords) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + setParams(self, inputCol="input", outputCol="output", stopWords=None,\ + caseSensitive=false) + Sets params for this StopWordRemover. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setStopWords(self, value): + """ + Specify the stopwords to be filtered. + """ + self._paramMap[self.stopWords] = value + return self + + def getStopWords(self): + """ + Get the stopwords. + """ + return self.getOrDefault(self.stopWords) + + def setCaseSensitive(self, value): + """ + Set whether to do a case sensitive comparison over the stop words + """ + self._paramMap[self.caseSensitive] = value + return self + + def getCaseSensitive(self): + """ + Get whether to do a case sensitive comparison over the stop words. + """ + return self.getOrDefault(self.caseSensitive) @inherit_doc @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ + .. note:: Experimental + A tokenizer that converts the input string to lowercase and then splits it by white spaces. @@ -982,6 +1248,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): """ + .. note:: Experimental + A feature transformer that merges multiple columns into a vector column. >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) @@ -1018,6 +1286,8 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Class for indexing categorical feature columns in a dataset of [[Vector]]. This has 2 usage modes: @@ -1060,6 +1330,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> model = indexer.fit(df) >>> model.transform(df).head().indexed DenseVector([1.0, 0.0]) + >>> model.numFeatures + 2 + >>> model.categoryMaps + {0: {0.0: 0, -1.0: 1}} >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test DenseVector([0.0, 1.0]) >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"} @@ -1117,14 +1391,119 @@ def _create_model(self, java_model): class VectorIndexerModel(JavaModel): """ + .. note:: Experimental + Model fitted by VectorIndexer. """ + @property + def numFeatures(self): + """ + Number of features, i.e., length of Vectors which this transforms. 
+ """ + return self._call_java("numFeatures") + + @property + def categoryMaps(self): + """ + Feature value index. Keys are categorical feature indices (column indices). + Values are maps from original features values to 0-based category indices. + If a feature is not in this map, it is treated as continuous. + """ + return self._call_java("javaCategoryMaps") + + +@inherit_doc +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + This class takes a feature vector and outputs a new feature vector with a subarray + of the original features. + + The subset of features can be specified with either indices (`setIndices()`) + or names (`setNames()`). At least one feature must be selected. Duplicate features + are not allowed, so there can be no overlap between selected indices and names. + + The output vector will order features with the selected indices first (in the order given), + followed by the selected names (in the order given). + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),), + ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),), + ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"]) + >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) + >>> vs.transform(df).head().sliced + DenseVector([2.3, 1.0]) + """ + + # a placeholder to make it appear in the generated doc + indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + names = Param(Params._dummy(), "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " + + "indices.") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + __init__(self, inputCol=None, outputCol=None, indices=None, names=None) + """ + super(VectorSlicer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) + self.indices = Param(self, "indices", "An array of indices to select features from " + + "a vector column. There can be no overlap with names.") + self.names = Param(self, "names", "An array of feature names to select features from " + + "a vector column. These names must be specified by ML " + + "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + + "with indices.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + """ + setParams(self, inputCol=None, outputCol=None, indices=None, names=None): + Sets params for this VectorSlicer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setIndices(self, value): + """ + Sets the value of :py:attr:`indices`. + """ + self._paramMap[self.indices] = value + return self + + def getIndices(self): + """ + Gets the value of indices or its default value. + """ + return self.getOrDefault(self.indices) + + def setNames(self, value): + """ + Sets the value of :py:attr:`names`. + """ + self._paramMap[self.names] = value + return self + + def getNames(self): + """ + Gets the value of names or its default value. 
+ """ + return self.getOrDefault(self.names) + @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): """ + .. note:: Experimental + Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further natural language processing or machine learning process. @@ -1238,6 +1617,8 @@ def _create_model(self, java_model): class Word2VecModel(JavaModel): """ + .. note:: Experimental + Model fitted by Word2Vec. """ @@ -1263,6 +1644,8 @@ def findSynonyms(self, word, num): @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): """ + .. note:: Experimental + PCA trains a model to project vectors to a low-dimensional space using PCA. >>> from pyspark.mllib.linalg import Vectors @@ -1318,6 +1701,8 @@ def _create_model(self, java_model): class PCAModel(JavaModel): """ + .. note:: Experimental + Model fitted by PCA. """ @@ -1401,6 +1786,8 @@ def _create_model(self, java_model): class RFormulaModel(JavaModel): """ + .. note:: Experimental + Model fitted by :py:class:`RFormula`. """ diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 69efc424ec4ef..5b39e5dd4e25b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,7 +121,19 @@ def get$Name(self): ("checkpointInterval", "checkpoint interval (>= 1)", None), ("seed", "random seed", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None)] + ("stepSize", "Step size to be used for each iteration of optimization.", None), + ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + + "out rows with bad values), or error (which will throw an errror). More options may be " + + "added later.", None), + ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), + ("fitIntercept", "whether to fit an intercept term.", "True"), + ("standardization", "whether to standardize the training features before fitting the " + + "model.", "True"), + ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + + "predicting each class. Array must have length equal to the number of classes, with " + + "values >= 0. The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class' threshold.", None)] code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 595124726366d..af1218128602b 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -432,6 +432,144 @@ def getStepSize(self): return self.getOrDefault(self.stepSize) +class HasHandleInvalid(Params): + """ + Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + """ + + # a placeholder to make it appear in the generated doc + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). 
More options may be added later.") + + def __init__(self): + super(HasHandleInvalid, self).__init__() + #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + self._paramMap[self.handleInvalid] = value + return self + + def getHandleInvalid(self): + """ + Gets the value of handleInvalid or its default value. + """ + return self.getOrDefault(self.handleInvalid) + + +class HasElasticNetParam(Params): + """ + Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + def __init__(self): + super(HasElasticNetParam, self).__init__() + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(elasticNetParam=0.0) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class HasFitIntercept(Params): + """ + Mixin for param fitIntercept: whether to fit an intercept term.. + """ + + # a placeholder to make it appear in the generated doc + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + + def __init__(self): + super(HasFitIntercept, self).__init__() + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self._setDefault(fitIntercept=True) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + +class HasStandardization(Params): + """ + Mixin for param standardization: whether to standardize the training features before fitting the model.. + """ + + # a placeholder to make it appear in the generated doc + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + + def __init__(self): + super(HasStandardization, self).__init__() + #: param for whether to standardize the training features before fitting the model. 
+ self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self._setDefault(standardization=True) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + self._paramMap[self.standardization] = value + return self + + def getStandardization(self): + """ + Gets the value of standardization or its default value. + """ + return self.getOrDefault(self.standardization) + + +class HasThresholds(Params): + """ + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + """ + + # a placeholder to make it appear in the generated doc + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def __init__(self): + super(HasThresholds, self).__init__() + #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value + return self + + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. @@ -444,7 +582,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -460,7 +598,7 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. 
Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 44f60a769566d..a9503608b7f25 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,7 +28,8 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol): + HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, + HasStandardization): """ Linear regression. @@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction TypeError: Method setParams forces keyword arguments. """ - # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. 
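With elasticNetParam, fitIntercept, and standardization now supplied by the shared-param mixins above, LinearRegression no longer declares elasticNetParam itself. A minimal usage sketch, assuming an active SparkContext and the 1.5-era API shown in this patch (the parameter values are illustrative):

    # Illustrative only; requires a running SparkContext.
    from pyspark.ml.regression import LinearRegression

    lr = LinearRegression(maxIter=10, regParam=0.1, elasticNetParam=0.8,
                          fitIntercept=True, standardization=False)
    print(lr.getElasticNetParam())   # 0.8, served by HasElasticNetParam.getElasticNetParam
    print(lr.getFitIntercept())      # True, inherited from HasFitIntercept
    # setters write into _paramMap and return self, so calls chain
    print(lr.setStandardization(True).getStandardization())   # True

Parameters that are not set explicitly fall back to the defaults registered with _setDefault via getOrDefault, exactly as in the mixins above.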
- """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 60e4237293adc..b892318f50bd9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,7 +31,7 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext +from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params @@ -258,7 +258,7 @@ def test_idf(self): def test_ngram(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ - ([["a", "b", "c", "d", "e"]])], ["input"]) + Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) self.assertEqual(ngram0.getInputCol(), "input") @@ -266,6 +266,22 @@ def test_ngram(self): transformedDF = ngram0.transform(dataset) self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + def test_stopwordsremover(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEquals(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["panda"]) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEquals(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a"]) + class HasInducedError(Params): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 253705bde913e..8218c7c5f801c 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -136,7 +136,8 @@ def _fit(self, dataset): class JavaTransformer(Transformer, JavaWrapper): """ Base class for :py:class:`Transformer`s that wrap Java/Scala - implementations. + implementations. Subclasses should ensure they have the transformer Java object + available as _java_obj. """ __metaclass__ = ABCMeta diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index ad9c891ba1c04..98eaf52866d23 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -44,21 +44,6 @@ from __future__ import absolute_import -def since(version): - """ - A decorator that annotates a function to append the version of Spark the function was added. - """ - import re - indent_p = re.compile(r'\n( +)') - - def deco(f): - indents = indent_p.findall(f.__doc__) - indent = ' ' * (min(len(m) for m in indents) if indents else 0) - f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. 
versionadded:: %s" % (indent, version) - return f - return deco - - from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0948f9b27cd38..9ca8e1f264cfa 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -22,9 +22,9 @@ basestring = str long = int +from pyspark import since from pyspark.context import SparkContext from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since from pyspark.sql.types import * __all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", @@ -91,6 +91,17 @@ def _(self): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ def __init__(self, jc): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") @@ -226,6 +239,9 @@ def __getattr__(self, item): raise AttributeError(item) return self.getField(item) + def __iter__(self): + raise TypeError("Column is not iterable") + # string methods rlike = _bin_op("rlike") like = _bin_op("like") diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0ef46c44644ab..89c8c6e0d94f1 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -26,9 +26,9 @@ from py4j.protocol import Py4JError +from pyspark import since from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter from pyspark.sql.dataframe import DataFrame diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e269ef4304f3f..fb995fa3a76b5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,11 +26,11 @@ else: from itertools import imap as map +from pyspark import since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter @@ -653,25 +653,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. 
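_bin_func_op routes Python's ** operator through org.apache.spark.sql.functions.pow, swapping the operands for the reflected form, and the new __iter__ makes accidental iteration over a Column fail fast. A sketch, assuming an active sqlContext; the column name is made up:

    # Illustrative only; assumes sqlContext exists.
    from pyspark.sql import Row
    df = sqlContext.createDataFrame([Row(age=2), Row(age=5)])

    squared = df.age ** 2      # pow(age, 2) as a Column, via __pow__
    reflected = 2 ** df.age    # pow(2, age), via __rpow__ with reverse=True
    df.select(squared, reflected).show()

    try:
        iter(df.age)           # __iter__ now raises immediately
    except TypeError as e:
        print(e)               # "Column is not iterable"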
>>> df.describe().show() - +-------+---+ - |summary|age| - +-------+---+ - | count| 2| - | mean|3.5| - | stddev|1.5| - | min| 2| - | max| 5| - +-------+---+ + +-------+------------------+ + |summary| age| + +-------+------------------+ + | count| 2| + | mean| 3.5| + | stddev|2.1213203435596424| + | min| 2| + | max| 5| + +-------+------------------+ >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4b74a501521a5..26b8662718a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,10 +24,9 @@ if sys.version < "3": from itertools import imap as map -from pyspark import SparkContext +from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql import since from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 04594d5a836ce..71c0bccc5eeff 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,8 +15,8 @@ # limitations under the License. # +from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql import since from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3fa6895880a97..f43d8bf646a9e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,8 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD -from pyspark.sql import since +from pyspark import RDD, since from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd32e26c64f22..f2172b7a27d88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -50,16 +50,17 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException -class UTC(datetime.tzinfo): - """UTC""" - ZERO = datetime.timedelta(0) +class UTCOffsetTimezone(datetime.tzinfo): + """ + Specifies timezone in UTC offset + """ + + def __init__(self, offset=0): + self.ZERO = datetime.timedelta(hours=offset) def utcoffset(self, dt): return self.ZERO - def tzname(self, dt): - return "UTC" - def dst(self, dt): return self.ZERO @@ -167,6 +168,11 @@ def test_decimal_type(self): t3 = DecimalType(8) self.assertNotEqual(t2, t3) + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + class SQLTests(ReusedPySparkTestCase): @@ -562,7 +568,7 @@ def test_column_operators(self): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 
- ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) @@ -841,13 +847,22 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_filter_with_datetime_timezone(self): + dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) + dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) + row = Row(date=dt1) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(0, df.filter(df.date == dt2).count()) + self.assertEqual(1, df.filter(df.date > dt2).count()) + self.assertEqual(0, df.filter(df.date < dt2).count()) + def test_time_with_timezone(self): day = datetime.date.today() now = datetime.datetime.now() ts = time.mktime(now.timetuple()) # class in __main__ is not serializable - from pyspark.sql.tests import UTC - utc = UTC() + from pyspark.sql.tests import UTCOffsetTimezone + utc = UTCOffsetTimezone() utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds # add microseconds to utcnow (keeping year,month,day,hour,minute,second) utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) @@ -1051,6 +1066,15 @@ def test_with_column_with_existing_name(self): keys = self.df.withColumn("key", self.df.key).select("key").collect() self.assertEqual([r.key for r in keys], list(range(100))) + # regression test for SPARK-10417 + def test_column_iterator(self): + + def foo(): + for x in self.df.key: + break + + self.assertRaises(TypeError, foo) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 94e581a78364c..1f86894855cbe 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -168,10 +168,12 @@ def needConversion(self): return True def toInternal(self, d): - return d and d.toordinal() - self.EPOCH_ORDINAL + if d is not None: + return d.toordinal() - self.EPOCH_ORDINAL def fromInternal(self, v): - return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + if v is not None: + return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) class TimestampType(AtomicType): @@ -1174,6 +1176,8 @@ class Row(tuple): >>> row = Row(name="Alice", age=11) >>> row Row(age=11, name='Alice') + >>> row['name'], row['age'] + ('Alice', 11) >>> row.name, row.age ('Alice', 11) @@ -1241,6 +1245,19 @@ def __call__(self, *args): """create new Row object""" return _create_row(self, args) + def __getitem__(self, item): + if isinstance(item, (int, slice)): + return super(Row, self).__getitem__(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__fields__.index(item) + return super(Row, self).__getitem__(idx) + except IndexError: + raise KeyError(item) + except ValueError: + raise ValueError(item) + def __getattr__(self, item): if item.startswith("__"): raise AttributeError(item) @@ -1290,8 +1307,11 @@ def can_convert(self, obj): def convert(self, obj, gateway_client): Timestamp = JavaClass("java.sql.Timestamp", gateway_client) - return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) - + seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo + else time.mktime(obj.timetuple())) + t = Timestamp(int(seconds) * 1000) + t.setNanos(obj.microsecond * 1000) + return t # 
datetime is a subclass of date, we should register DatetimeConverter first register_input_converter(DatetimeConverter()) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index eaf4d7e98620a..57bbe340bbd4d 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -17,8 +17,7 @@ import sys -from pyspark import SparkContext -from pyspark.sql import since +from pyspark import since, SparkContext from pyspark.sql.column import _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index bf609ff0f65fc..33d262558b1fc 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -118,5 +118,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + def run(lines: List[String]): String = run(lines.map(_ + "\n").mkString) } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68fdb4141cf27..64a0c71bbef2a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -168,6 +168,14 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + ^getConfiguration$|^getTaskAttemptID$ + Instead of calling .getConfiguration() or .getTaskAttemptID() directly, + use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1a5de15c61f86..591747b45c376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -560,43 +560,47 @@ class Analyzer( filter } - case sort @ Sort(sortOrder, global, - aggregate @ Aggregate(grouping, originalAggExprs, child)) + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved && !sort.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) - val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) - val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions - - // Expressions that have an aggregate can be pushed down. - val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) - - // Attribute references, that are missing from the order but are present in the grouping - // expressions can also be pushed down. - val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) - val missingAttributes = requiredAttributes -- aggregate.outputSet - val validPushdownAttributes = - missingAttributes.filter(a => grouping.exists(a.semanticEquals)) - - // If resolution was successful and we see the ordering either has an aggregate in it or - // it is missing something that is projected away by the aggregate, add the ordering - // the original aggregate operator. 
- if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) - } - val aggExprsWithOrdering: Seq[NamedExpression] = - resolvedAggregateOrdering ++ originalAggExprs - - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) - } else { - sort + val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] + + // If we pass the analysis check, then the ordering expressions should only reference to + // aggregate expressions or grouping expressions, and it's safe to push them down to + // Aggregate. + checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is same with original aggregate expression, we don't need + // to push down this ordering expression and can reference the original aggregate + // expression instead. + val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } + + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } } + + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. 
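The rewritten rule resolves the ORDER BY against a copy of the Aggregate: when an ordering expression semantically matches one of the original aggregate expressions it simply references that attribute, otherwise the aliased "aggOrder" expression is pushed into the Aggregate and projected away afterwards. A sketch of the query shapes this covers, with made-up table and column names and an assumed sqlContext:

    # Illustrative only; assumes sqlContext exists.
    df = sqlContext.createDataFrame([("a", 1), ("a", 2), ("b", 5)], ["k", "v"])
    df.registerTempTable("t")

    # ORDER BY references an aggregate that is not in the select list: the ordering
    # expression is pushed into the Aggregate as an extra aliased column and then
    # projected out again.
    sqlContext.sql("SELECT k FROM t GROUP BY k ORDER BY max(v)").show()

    # ORDER BY matches an existing aggregate expression: nothing extra is pushed
    # down; the Sort just references the original aggregate's output attribute.
    sqlContext.sql("SELECT k, max(v) AS m FROM t GROUP BY k ORDER BY max(v)").show()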
@@ -605,9 +609,7 @@ class Analyzer( } protected def containsAggregate(condition: Expression): Boolean = { - condition - .collect { case ae: AggregateExpression => ae } - .nonEmpty + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d788151..11b4866bf264b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,6 +168,9 @@ object FunctionRegistry { expression[Last]("last"), expression[Max]("max"), expression[Min]("min"), + expression[Stddev]("stddev"), + expression[StddevPop]("stddev_pop"), + expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), // string functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87c11abbad490..87a3845b2d9e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,6 +297,9 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a7e3a49327655..699c4cc63d09a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,6 +159,9 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def stddev(e: Expression): Expression = Stddev(e) + def stddev_pop(e: Expression): Expression = StddevPop(e) + def stddev_samp(e: Expression): Expression = StddevSamp(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? 
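stddev, stddev_pop, and stddev_samp are now registered as SQL functions, with stddev being the sample flavor. That is why the describe() doctest above moves from 1.5 to 2.1213203435596424 for ages [2, 5]: M2 = (2 - 3.5)^2 + (5 - 3.5)^2 = 4.5, so the sample estimate is sqrt(4.5 / 1) ≈ 2.1213 while the population estimate is sqrt(4.5 / 2) = 1.5. The aggregate implementations added below keep a (count, avg, M2) buffer and apply the online update plus the pairwise merge from the cited algorithm; a minimal Python sketch of those recurrences (not the Catalyst expressions themselves):

    import math

    def update(buf, x):
        count, avg, m2 = buf
        count += 1
        pre_avg = avg
        avg += (x - avg) / count            # avg = avg + (value - avg) / count
        m2 += (x - pre_avg) * (x - avg)     # Mk = Mk + (value - preAvg) * (value - newAvg)
        return count, avg, m2

    def merge(a, b):
        (n1, avg1, m1), (n2, avg2, m2) = a, b
        n = n1 + n2
        delta = avg2 - avg1
        avg = (avg1 * n1 + avg2 * n2) / n               # weighted average merge
        m = m1 + m2 + delta * delta * n1 * n2 / n       # Mk merge with the delta correction
        return n, avg, m

    b1 = update((0, 0.0, 0.0), 2.0)     # partition 1
    b2 = update((0, 0.0, 0.0), 5.0)     # partition 2
    n, _, m2 = merge(b1, b2)
    print(math.sqrt(m2 / (n - 1)))      # 2.1213203435596424 -> stddev / stddev_samp
    print(math.sqrt(m2 / n))            # 1.5                -> stddev_pop

The evaluate step mirrors the Scala code below: null for an empty buffer, 0.0 when only one value was seen, and otherwise the square root of M2 divided by count - 1 (sample) or count (population).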
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2db954257be35..f0bce388d959a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba1..02cd0ac0db118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = min } +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev" +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + def isSample: Boolean + + // Return data type. 
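The Cast change above replaces "any non-empty string is true" with a lookup against StringUtils.isTrueString / isFalseString, so unrecognized strings now cast to null rather than true. A sketch of the visible behavior; the exact accepted spellings live in StringUtils, which is not part of this hunk, so treat the sample values as assumptions:

    # Illustrative only; assumes sqlContext exists and that "true"/"false" are
    # among the spellings StringUtils accepts.
    df = sqlContext.createDataFrame([("true",), ("false",), ("maybe",), ("",)], ["s"])
    df.selectExpr("CAST(s AS boolean)").show()
    # before: any non-empty string -> true, "" -> false
    # after:  recognized true spellings -> true, recognized false spellings -> false,
    #         anything else (e.g. "maybe") -> null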
+ override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val preCount = AttributeReference("preCount", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + private val preAvg = AttributeReference("preAvg", resultType)() + private val currentAvg = AttributeReference("currentAvg", resultType)() + private val currentMk = AttributeReference("currentMk", resultType)() + + override val bufferAttributes = preCount :: currentCount :: preAvg :: + currentAvg :: currentMk :: Nil + + override val initialValues = Seq( + /* preCount = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType), + /* preAvg = */ Cast(Literal(0), resultType), + /* currentAvg = */ Cast(Literal(0), resultType), + /* currentMk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + + // update average + // avg = avg + (value - avg)/count + def avgAdd: Expression = { + currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) + } + + // update sum of square of difference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + def mkAdd: Expression = { + val delta1 = Cast(child, resultType) - preAvg + val delta2 = Cast(child, resultType) - currentAvg + currentMk + (delta1 * delta2) + } + + Seq( + /* preCount = */ If(IsNull(child), preCount, currentCount), + /* currentCount = */ If(IsNull(child), currentCount, + Add(currentCount, Cast(Literal(1), resultType))), + /* preAvg = */ If(IsNull(child), preAvg, currentAvg), + /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), + /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + ) + } + + override val mergeExpressions = { + + // count merge + def countMerge: Expression = { + currentCount.left + currentCount.right + } + + // average merge + def avgMerge: Expression = { + ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / + (preCount + currentCount.right) + } + + // update sum of square differences + def mkMerge: Expression = { + val avgDelta = currentAvg.right - preAvg + val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / + (preCount + currentCount.right) + + currentMk.left + currentMk.right + mkDelta + } + + Seq( + /* preCount = */ If(IsNull(currentCount.left), + Cast(Literal(0), resultType), currentCount.left), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, countMerge)), + /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), + /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, + If(IsNull(currentMk.right), currentMk.left, mkMerge)) + ) + } + + override val evaluateExpression = { + // when currentCount == 0, return null + // when currentCount == 1, return 0 + // when 
currentCount >1 + // stddev_samp = sqrt (currentMk/(currentCount -1)) + // stddev_pop = sqrt (currentMk/currentCount) + val varCol = { + if (isSample) { + currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) + } + else { + currentMk / currentCount + } + } + + If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} + case class Sum(child: Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a95490..ce3dddad87f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -85,6 +85,24 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Stddev(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Stddev(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevPop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevPop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevSamp(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Sum(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Sum(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9cb..f1c47f39043c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag result } } + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + + def isSample: Boolean + + override def asPartial: SplitEvaluation = { + val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() + SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) + } + + override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") + +} + +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV($child)" + override def isSample: Boolean = true +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_POP($child)" + override def isSample: Boolean = false +} + +// Compute the 
sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_SAMP($child)" + override def isSample: Boolean = true +} + +case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) +} + +case class ComputePartialStdFunction ( + expr: Expression, + base: AggregateExpression1 +) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private var partialCount: Long = 0L + + // the mean of data processed so far + private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update average based on this formula: + // avg = avg + (value - avg)/count + private def avgAddFunction (value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), partialAvg) + Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) + } + + // the sum of squares of difference from mean + private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update sum of square of difference from mean based on following formula: + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), prePartialAvg) + val delta2 = Subtract(Cast(value, computeType), partialAvg) + Add(partialMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + val prePartialAvg = partialAvg.copy() + partialCount += 1 + partialAvg.update(avgAddFunction(exprValue), input) + partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), + partialAvg.eval(null), + partialMk.eval(null))) + } +} + +case class MergePartialStd( + child: Expression, + isSample: Boolean +) extends UnaryExpression with AggregateExpression1 { + def this() = this(null, false) // required for serialization + + override def children: Seq[Expression] = child:: Nil + override def nullable: Boolean = false + override def dataType: DataType = DoubleType + override def toString: String = s"MergePartialStd($child)" + override def newInstance(): MergePartialStdFunction = { + new MergePartialStdFunction(child, this, isSample) + } +} + +case class MergePartialStdFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + def this() = this (null, null, false) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private val combineCount = MutableLiteral(zero.eval(null), computeType) + private val combineAvg = MutableLiteral(zero.eval(null), computeType) + private val combineMk = MutableLiteral(zero.eval(null), computeType) + + private def 
avgUpdateFunction(preCount: Expression, + partialCount: Expression, + partialAvg: Expression): Expression = { + Divide(Add(Multiply(combineAvg, preCount), + Multiply(partialAvg, partialCount)), + Add(preCount, partialCount)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] + + if (evaluatedExpr != null) { + val exprValue = evaluatedExpr.toArray(computeType) + val (partialCount, partialAvg, partialMk) = + (Literal.create(exprValue(0), computeType), + Literal.create(exprValue(1), computeType), + Literal.create(exprValue(2), computeType)) + + if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { + val preCount = combineCount.copy() + combineCount.update(Add(combineCount, partialCount), input) + + val preAvg = combineAvg.copy() + val avgDelta = Subtract(partialAvg, preAvg) + val mkDelta = Multiply(Multiply(avgDelta, avgDelta), + Divide(Multiply(preCount, partialCount), + combineCount)) + + // update average based on following formula + // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) + combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) + + // update sum of square differences from mean based on following formula + // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) + combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) + } + } + } + + override def eval(input: InternalRow): Any = { + val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] + + if (count == 0) null + else if (count < 2) zero.eval(null) + else { + // when total count > 2 + // stddev_samp = sqrt (combineMk/(combineCount -1)) + // stddev_pop = sqrt (combineMk/combineCount) + val varCol = { + if (isSample) { + Divide(combineMk, Cast(Literal(count - 1), computeType)) + } + else { + Divide(combineMk, Cast(Literal(count), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} + +case class StddevFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + + def this() = this(null, null, false) // Required for serialization + + private val computeType = DoubleType + private var curCount: Long = 0L + private val zero = Cast(Literal(0), computeType) + private val curAvg = MutableLiteral(zero.eval(null), computeType) + private val curMk = MutableLiteral(zero.eval(null), computeType) + + private def curAvgAddFunction(value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), curAvg) + Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) + } + private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), preAvg) + val delta2 = Subtract(Cast(value, computeType), curAvg) + Add(curMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val preAvg: MutableLiteral = curAvg.copy() + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + curCount += 1L + curAvg.update(curAvgAddFunction(exprValue), input) + curMk.update(curMkAddFunction(exprValue, preAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + if (curCount == 0) null + else if (curCount < 2) zero.eval(null) + else { + // when total count > 2, + // stddev_samp = sqrt(curMk/(curCount - 1)) + // stddev_pop = sqrt(curMk/curCount) + val varCol = { + if (isSample) { + Divide(curMk, 
Cast(Literal(curCount - 1), computeType)) + } + else { + Divide(curMk, Cast(Literal(curCount), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf96248feaef7..da3103b4ebb6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -84,11 +84,11 @@ class CodeGenContext { /** * Holding all the functions those will be added into generated class. */ - val addedFuntions: mutable.Map[String, String] = + val addedFunctions: mutable.Map[String, String] = mutable.Map.empty[String, String] def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFuntions += ((funcName, funcCode)) + addedFunctions += ((funcName, funcCode)) } final val JAVA_BOOLEAN = "boolean" @@ -298,8 +298,8 @@ class CodeGenContext { | $body |} """.stripMargin - addNewFunction(name, code) - name + addNewFunction(name, code) + name } functions.map(name => s"$name($row);").mkString("\n") @@ -337,7 +337,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b4d4df8934bd4..793023b9fbed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types.DecimalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e8..2164ddf03d1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -48,7 +48,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val columns = expressions.zipWithIndex.map { case (e, i) => s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n ") + }.mkString("\n") val initColumns = expressions.zipWithIndex.map { case (e, i) => @@ -67,18 +67,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val getCases = (0 until expressions.size).map { i => s"case $i: return c$i;" - }.mkString("\n ") + }.mkString("\n") val updateCases = expressions.zipWithIndex.map { case (e, i) => s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; 
return;}" - }.mkString("\n ") + }.mkString("\n") val specificAccessorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: return c$i;") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val getter = "get" + ctx.primitiveTypeName(jt) s""" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case (e, i) if ctx.javaType(e.dataType) == jt => Some(s"case $i: { c$i = value; return; }") case _ => None - }.mkString("\n ") + }.mkString("\n") if (cases.length > 0) { val setter = "set" + ctx.primitiveTypeName(jt) s""" @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val copyColumns = expressions.zipWithIndex.map { case (e, i) => s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n ") + }.mkString("\n") val code = s""" public SpecificProjection generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b570fe86db1aa..55562facf9652 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -134,7 +134,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") ctx.addMutableState("int", cursor, s"this.$cursor = 0;") - val tmp = ctx.freshName("tmpBuffer") + val tmpBuffer = ctx.freshName("tmpBuffer") val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => val ev = createConvertCode(ctx, input, dt) @@ -144,10 +144,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. 
- byte[] $tmp = new byte[$numBytes * 2]; + byte[] $tmpBuffer = new byte[$numBytes * 2]; Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, - $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); - $buffer = $tmp; + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; } $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ @@ -206,67 +206,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val tmpBuffer = ctx.freshName("tmpBuffer") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") - val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") + val elementName = ctx.freshName("elementName") - val element = GeneratedExpressionCode( - code = "", - isNull = s"$tmp.isNullAt($index)", - primitive = s"${ctx.getValue(tmp, elementType, index)}" - ) - val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType) - - // go through the input array to calculate how many bytes we need. - val calculateNumBytes = elementType match { - case _ if ctx.isPrimitiveType(elementType) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - $numBytes += $elementSize * $numElements; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - $numBytes += 8 * $numElements; - """ - case _ => - val writer = getWriter(elementType) - val elementSize = s"$writer.getSize($elements[$index])" - // TODO(davies): avoid the copy - val unsafeType = elementType match { - case _: StructType => "UnsafeRow" - case _: ArrayType => "UnsafeArrayData" - case _: MapType => "UnsafeMapData" - case _ => ctx.javaType(elementType) - } - val copy = elementType match { - // We reuse the buffer during conversion, need copy it before process next element. - case _: StructType | _: ArrayType | _: MapType => ".copy()" - case _ => "" - } - - val newElements = if (elementType == BinaryType) { - s"new byte[$numElements][]" - } else { - s"new $unsafeType[$numElements]" - } - s""" - final $unsafeType[] $elements = $newElements; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if (!${convertedElement.isNull}) { - $elements[$index] = ${convertedElement.primitive}$copy; - $numBytes += $elementSize; - } - } - """ + val element = { + val code = s"${ctx.javaType(elementType)} $elementName = " + + s"${ctx.getValue(input.primitive, elementType, index)};" + val isNull = s"${input.primitive}.isNullAt($index)" + GeneratedExpressionCode(code, isNull, elementName) } + val convertedElement = createConvertCode(ctx, element, elementType) + val writeElement = elementType match { case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? 
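The projection reuses one byte[] across rows; when a row or element needs more room, the generated code allocates a buffer of twice the required size and copies the old contents over, and the array conversion rewritten below now grows the buffer the same way during a single pass over the elements instead of precomputing the total size first. A grow-on-write sketch of that pattern in Python (illustrative only, not the generated Java):

    # Mirrors "new byte[numBytes * 2]" plus Platform.copyMemory in the codegen above.
    buf = bytearray(64)     # reused buffer, like the byte[] mutable state
    used = 0

    def ensure_capacity(buf, needed):
        if len(buf) < needed:
            bigger = bytearray(needed * 2)   # double the required size
            bigger[:len(buf)] = buf          # copy existing contents
            return bigger
        return buf

    for element in (b"abc", b"defgh" * 100):
        buf = ensure_capacity(buf, used + len(element))
        buf[used:used + len(element)] = element
        used += len(element)
    print(len(buf), used)

Because the buffer only grows and is kept between rows, the copy is rare in steady state, which is what the "will not happen frequently" comment in the generated code relies on.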
@@ -292,41 +249,51 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $cursor += $writer.write( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, - $elements[$index]); + ${convertedElement.primitive}); """ } - val checkNull = elementType match { - case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" - case t: DecimalType => s"$elements[$index] == null" + - s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" - case _ => s"$elements[$index] == null" + val checkNull = convertedElement.isNull + (elementType match { + case t: DecimalType => + s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})" + case _ => "" + }) + + val elementSize = elementType match { + // Should we do word align for primitive types? + case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8" + case _ => + val writer = getWriter(elementType) + s"$writer.getSize(${convertedElement.primitive})" } val code = s""" ${input.code} final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final ArrayData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeArrayData) { - $output = (UnsafeArrayData) $tmp; + if (${input.primitive} instanceof UnsafeArrayData) { + $output = (UnsafeArrayData) ${input.primitive}; } else { - final int $numElements = $tmp.numElements(); + final int $numElements = ${input.primitive}.numElements(); final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; - $calculateNumBytes - - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - int $cursor = $fixedSize; for (int $index = 0; $index < $numElements; $index++) { + ${convertedElement.code} if ($checkNull) { // If element is null, write the negative value address into offset region. Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { + $numBytes += $elementSize; + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. 
+ byte[] $tmpBuffer = new byte[$numBytes * 2]; + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; + } Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } @@ -350,29 +317,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro valueType: DataType): GeneratedExpressionCode = { val output = ctx.freshName("convertedMap") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") - - val keyArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.keyArray()" - ) - val valueArray = GeneratedExpressionCode( - code = "", - isNull = "false", - primitive = s"$tmp.valueArray()" - ) - val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType) - val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType) + val keyArrayName = ctx.freshName("keyArrayName") + val valueArrayName = ctx.freshName("valueArrayName") + + val keyArray = { + val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, keyArrayName) + } + + val valueArray = { + val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();" + val isNull = "false" + GeneratedExpressionCode(code, isNull, valueArrayName) + } + + val convertedKeys = createCodeForArray(ctx, keyArray, keyType) + val convertedValues = createCodeForArray(ctx, valueArray, valueType) val code = s""" ${input.code} final boolean $outputIsNull = ${input.isNull}; UnsafeMapData $output = null; if (!$outputIsNull) { - final MapData $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeMapData) { - $output = (UnsafeMapData) $tmp; + if (${input.primitive} instanceof UnsafeMapData) { + $output = (UnsafeMapData) ${input.primitive}; } else { ${convertedKeys.code} ${convertedValues.code} @@ -393,22 +362,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => val output = ctx.freshName("convertedStruct") val outputIsNull = ctx.freshName("isNull") - val tmp = ctx.freshName("tmp") val fieldTypes = t.fields.map(_.dataType) val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val getFieldCode = ctx.getValue(tmp, dt, i.toString) - val fieldIsNull = s"$tmp.isNullAt($i)" - GeneratedExpressionCode("", fieldIsNull, getFieldCode) + val fieldName = ctx.freshName("fieldName") + val code = s"${ctx.javaType(dt)} $fieldName = " + + s"${ctx.getValue(input.primitive, dt, i.toString)};" + val isNull = s"${input.primitive}.isNullAt($i)" + GeneratedExpressionCode(code, isNull, fieldName) } - val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; final boolean $outputIsNull = ${input.isNull}; if (!$outputIsNull) { - final InternalRow $tmp = ${input.primitive}; - if ($tmp instanceof UnsafeRow) { - $output = (UnsafeRow) $tmp; + if (${input.primitive} instanceof UnsafeRow) { + $output = (UnsafeRow) ${input.primitive}; } else { ${converter.code} $output = ${converter.primitive}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1c546719730b7..82eab5fb3d03a 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -48,21 +48,22 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") s""" final boolean ${ev.isNull} = false; - final Object[] values = new Object[${children.size}]; + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - values[$i] = null; + $values[$i] = null; } else { - values[$i] = ${eval.primitive}; + $values[$i] = ${eval.primitive}; } """ }.mkString("\n") + - s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" + s"final ArrayData ${ev.primitive} = new $arrayClass($values);" } override def prettyName: String = "array" @@ -94,21 +95,23 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + final Object[] $values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "struct" @@ -161,21 +164,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName + val rowClass = classOf[GenericInternalRow].getName + val values = ctx.freshName("values") s""" boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + final Object[] $values = new Object[${valExprs.size}]; """ + valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + $values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + $values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final InternalRow ${ev.primitive} = new $rowClass($values);" } override def prettyName: String = "named_struct" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 65706dba7d975..daefc016bc91c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -67,7 +67,7 @@ trait PredicateHelper { /** * Returns true if `expr` can be evaluated using only the output of `plan`. 
This method - * can be used to determine when is is acceptable to move expression evaluation within a query + * can be used to determine when it is acceptable to move expression evaluation within a query * plan. * * For example consider a join between two relations R(a, b) and S(c, d). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 48d02bb534501..a09d5b6e3ad14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -255,7 +255,7 @@ object StringTranslate { val dict = new HashMap[Character, Character]() var i = 0 while (i < matching.length()) { - val rep = if (i < replace.length()) replace.charAt(i) else '\0' + val rep = if (i < replace.length()) replace.charAt(i) else '\u0000' if (null == dict.get(matching.charAt(i))) { dict.put(matching.charAt(i), rep) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a430000bef653..0f4caec7451a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -434,6 +434,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case (_, Literal(false, BooleanType)) => Literal(false) // a && a => a case (l, r) if l fastEquals r => l + // a && (not(a) || b) => a && b + case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) + case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, @@ -512,6 +517,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case LessThan(l, r) => GreaterThanOrEqual(l, r) // not(l <= r) => l > r case LessThanOrEqual(l, r) => GreaterThan(l, r) + // not(l || r) => not(l) && not(r) + case Or(l, r) => And(Not(l), Not(r)) + // not(l && r) => not(l) or not(r) + case And(l, r) => Or(Not(l), Not(r)) // not(not(e)) => e case Not(e) => e case _ => not diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 73a21884a4710..56a3dd02f9ba3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -51,7 +51,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { * filled in automatically by the QueryPlanner using the other execution strategies that are * available. */ - protected def planLater(plan: LogicalPlan) = this.plan(plan).next() + protected def planLater(plan: LogicalPlan): PhysicalPlan = this.plan(plan).next() def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // Obviously a lot to do here still... 
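The cases added to BooleanSimplification above encode the absorption identity a && (!a || b) == a && b (in all four operand orders) and, in the Not branch, De Morgan's laws. A toy sketch of the same rewrites on a simplified expression tree (Var/Not/And/Or here are stand-ins, not Catalyst's expression classes):

    sealed trait Expr
    case class Var(name: String) extends Expr
    case class Not(child: Expr) extends Expr
    case class And(left: Expr, right: Expr) extends Expr
    case class Or(left: Expr, right: Expr) extends Expr

    def simplify(e: Expr): Expr = e match {
      // a && (!a || b)  =>  a && b
      case And(l, Or(l1, r)) if l1 == Not(l) => And(l, r)
      // !(a || b)  =>  !a && !b
      case Not(Or(l, r)) => And(Not(l), Not(r))
      // !(a && b)  =>  !a || !b
      case Not(And(l, r)) => Or(Not(l), Not(r))
      case other => other
    }

    // simplify(And(Var("a"), Or(Not(Var("a")), Var("b")))) == And(Var("a"), Var("b"))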
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b9ca712c1ee1c..53537799517ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,35 +17,12 @@ package org.apache.spark.sql.catalyst.planning -import scala.annotation.tailrec - import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -/** - * A pattern that matches any number of filter operations on top of another relational operator. - * Adjacent filter operators are collected and their conditions are broken up and returned as a - * sequence of conjunctive predicates. - * - * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the - * output and a relational operator. - */ -object FilteredOperation extends PredicateHelper { - type ReturnType = (Seq[Expression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) - - @tailrec - private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { - case Filter(condition, child) => - collectFilters(filters ++ splitConjunctivePredicates(condition), child) - case other => (filters, other) - } -} - /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -62,8 +39,9 @@ object PhysicalOperation extends PredicateHelper { } /** - * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two - * examples for alias in-lining/substitution. Before: + * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * Here are two examples for alias in-lining/substitution. 
+ * Before: * {{{ * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 @@ -74,15 +52,15 @@ object PhysicalOperation extends PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - def collectProjectsAndFilters(plan: LogicalPlan): + private def collectProjectsAndFilters(plan: LogicalPlan): (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = plan match { - case Project(fields, child) => + case Project(fields, child) if fields.forall(_.deterministic) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) - case Filter(condition, child) => + case Filter(condition, child) if condition.deterministic => val (fields, filters, other, aliases) = collectProjectsAndFilters(child) val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) @@ -91,11 +69,11 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { + private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) @@ -195,8 +173,9 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // as join keys. 
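The deterministic guards added to collectProjectsAndFilters above keep PhysicalOperation from merging projects or filters across a non-deterministic expression such as rand(), which would change how often that expression is evaluated. A compressed sketch of the stop-at-non-deterministic collection over a toy plan (Relation/Filter are stand-ins, not Catalyst operators):

    sealed trait Plan
    case class Relation(name: String) extends Plan
    case class Filter(condition: String, deterministic: Boolean, child: Plan) extends Plan

    // Collect filter conditions top-down, but stop as soon as a non-deterministic
    // condition is reached so it is not reordered or merged with the ones above it.
    def collectFilters(plan: Plan, acc: Seq[String] = Nil): (Seq[String], Plan) = plan match {
      case Filter(cond, true, child) => collectFilters(child, acc :+ cond)
      case other                     => (acc, other)
    }

    // collectFilters(Filter("a > 1", true, Filter("rand() < 0.5", false, Relation("t"))))
    //   == (Seq("a > 1"), Filter("rand() < 0.5", false, Relation("t")))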
val (joinPredicates, otherPredicates) = condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { - case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || - (canEvaluate(l, right) && canEvaluate(r, left)) => true + case EqualTo(l, r) => + (canEvaluate(l, left) && canEvaluate(r, right)) || + (canEvaluate(l, right) && canEvaluate(r, left)) case _ => false } @@ -204,10 +183,9 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r) case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l) } - val leftKeys = joinKeys.map(_._1) - val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { + val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9bb466ac2d29c..8f8747e105932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -135,16 +135,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { val input = children.flatMap(_.output) + def cleanExpression(e: Expression) = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + BindReferences.bindReference(cleanedExprId, input, allowFailures = true) + case other => BindReferences.bindReference(other, input, allowFailures = true) + } + productIterator.map { // Children are checked using sameResult above. 
case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case s: Option[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case s: Seq[_] => s.map { - case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case e: Expression => cleanExpression(e) case other => other } case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d3759..c2eeb3c5650ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index e0667c629486d..1d2d007c2b4d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -126,7 +126,7 @@ protected[sql] object AnyDataType extends AbstractDataType { */ protected[sql] abstract class AtomicType extends DataType { private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val tag: TypeTag[InternalType] private[sql] val ordering: Ordering[InternalType] @transient private[sql] val classTag = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7bcd623b3f33e..4b54c31dcc27a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -51,7 +51,9 @@ abstract class DataType extends AbstractDataType { def defaultSize: Int /** Name of the type used in JSON serialization. 
*/ - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + def typeName: String = { + this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + } private[sql] def jsonValue: JValue = typeName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 11e0c120f4072..4025cbcec1019 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -23,6 +23,8 @@ import java.math.MathContext import scala.util.Random +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -84,6 +86,7 @@ object RandomDataGenerator { * random data generator is defined for that data type. The generated values will use an external * representation of the data type; for example, the random generator for [[DateType]] will return * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. + * For a [[UserDefinedType]] for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated @@ -106,7 +109,22 @@ object RandomDataGenerator { }) case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) - case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case TimestampType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". + milliseconds = rand.nextLong() % 253402329599999L + } + // DateTimeUtils.toJavaTimestamp takes microsecond. + DateTimeUtils.toJavaTimestamp(milliseconds * 1000) + } + Some(generator) case CalendarIntervalType => Some(() => { val months = rand.nextInt(1000) val ns = rand.nextLong() @@ -159,6 +177,27 @@ object RandomDataGenerator { None } } + case udt: UserDefinedType[_] => { + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + // Because random data generator at here returns scala value, we need to + // convert it to catalyst value to call udt's deserialize. 
+ val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) + + if (maybeSqlTypeGenerator.isDefined) { + val sqlTypeGenerator = maybeSqlTypeGenerator.get + val generator = () => { + val generatedScalaValue = sqlTypeGenerator.apply() + if (generatedScalaValue == null) { + null + } else { + udt.deserialize(toCatalystType(generatedScalaValue)) + } + } + Some(generator) + } else { + None + } + } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733eae03..f4db4da7646f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, 
Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" ->"123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", 
false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1877cff1334bd..cde346e99eb17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -89,6 +89,22 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) } + test("a && (!a || b)") { + checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + + checkCondition(('a && ('b || !('a) )), ('a && 'b)) + + checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + + checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + } + + test("!(a && b) , !(a || b)") { + checkCondition((!('a && 'b)), (!('a) || !('b))) + + checkCondition(!('a || 'b), (!('a) && !('b))) + } + private val caseInsensitiveAnalyzer = new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 765c1e2dda99f..f76a903dcc9cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util._ * Provides helper methods for comparing plans. */ class PlanTest extends SparkFunSuite { - /** * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 791c10c3d7ce7..1a687b2374f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1288,15 +1288,11 @@ class DataFrame private[sql]( @scala.annotation.varargs def describe(cols: String*): DataFrame = { - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) - // The list of summary statistics to compute, in the form of expressions. 
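Together, the new StringUtils.isTrueString/isFalseString helpers and the CastSuite cases above pin down the string-to-boolean mapping: "t", "true", "y", "yes", "1" cast to true, "f", "false", "n", "no", "0" cast to false (case-insensitively), and anything else, including the empty string, casts to null. A plain-Scala sketch of the same three-valued cast, using String and Option in place of UTF8String and SQL null:

    val trueStrings  = Set("t", "true", "y", "yes", "1")
    val falseStrings = Set("f", "false", "n", "no", "0")

    def castToBoolean(s: String): Option[Boolean] = {
      val lower = s.toLowerCase
      if (trueStrings.contains(lower)) Some(true)
      else if (falseStrings.contains(lower)) Some(false)
      else None                                   // e.g. "abc" or "" become SQL NULL
    }

    // castToBoolean("tRUe") == Some(true); castToBoolean("abc") == None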
val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> stddevExpr, + "stddev" -> Stddev, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee31d83cce42c..102b802ad0a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -124,6 +124,9 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min + case "stddev" => Stddev + case "stddev_pop" => StddevPop + case "stddev_samp" => StddevSamp case "sum" => Sum case "count" | "size" => // Turn count(*) into count(1) @@ -283,6 +286,42 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Min) } + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Stddev) + } + + /** + * Compute the population standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevPop) + } + + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevSamp) + } + /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d084..75f2979951acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -290,9 +290,9 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "Enables Parquet filter push-down optimization when set to true.") - val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( - key = "spark.sql.parquet.followParquetFormatSpec", - defaultValue = Some(false), + val PARQUET_WRITE_LEGACY_FORMAT = booleanConf( + key = "spark.sql.parquet.writeLegacyFormat", + defaultValue = Some(true), doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + "Spark SQL schema and vice versa.", isPublic = false) @@ -304,8 +304,7 @@ private[spark] object SQLConf { "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + "option must be set in Hadoop Configuration. 2. This option overrides " + - "\"spark.sql.sources.outputCommitterClass\"." 
- ) + "\"spark.sql.sources.outputCommitterClass\".") val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), @@ -497,7 +496,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) + private[spark] def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 7f3defec3d42e..d4b834adb6e39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpres import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} +import scala.util.matching.Regex + private[r] object SQLUtils { def createSQLContext(jsc: JavaSparkContext): SQLContext = { new SQLContext(jsc) @@ -35,14 +37,15 @@ private[r] object SQLUtils { new JavaSparkContext(sqlCtx.sparkContext) } - def toSeq[T](arr: Array[T]): Seq[T] = { - arr.toSeq - } - def createStructType(fields : Seq[StructField]): StructType = { StructType(fields) } + // Support using regex in string interpolation + private[this] implicit class RegexContext(sc: StringContext) { + def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) + } + def getSQLDataType(dataType: String): DataType = { dataType match { case "byte" => org.apache.spark.sql.types.ByteType @@ -58,6 +61,9 @@ private[r] object SQLUtils { case "boolean" => org.apache.spark.sql.types.BooleanType case "timestamp" => org.apache.spark.sql.types.TimestampType case "date" => org.apache.spark.sql.types.DateType + case r"\Aarray<(.*)${elemType}>\Z" => { + org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 531a8244d55d1..ab482a3636121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -339,6 +339,8 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { setField(to, toOrdinal, getField(from, fromOrdinal)) } + + override def clone(v: UTF8String): UTF8String = v.clone() } private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4df53687a0731..4572d5efc92bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -87,7 +87,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - 
side: joins.BuildSide) = { + side: joins.BuildSide): Seq[SparkPlan] = { val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil @@ -123,24 +123,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Outer joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys( - LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys( - RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => joins.SortMergeOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil - - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( @@ -156,11 +151,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Aggregations that can be performed in two phases, before and after the shuffle. case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) if !canBeConvertedToNewAggregation(plan) => execution.Aggregate( partial = false, namedGroupingAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 5c18558f9bde7..e060c06d9e2a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -72,7 +72,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def writeKey[T: ClassTag](key: T): SerializationStream = { // The key is only needed on the map side when computing partition ids. It does not need to // be shuffled. 
- assert(key.isInstanceOf[Int]) + assert(null == key || key.isInstanceOf[Int]) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3f68b05a24f44..bf6d44c098ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ @DeveloperApi case class Sample( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 879fd69863211..f8ef674ed29c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.SerializableConfiguration private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, - @transient job: Job, + @transient private val job: Job, isAppend: Boolean) extends SparkHadoopMapReduceUtil with Logging @@ -47,7 +47,8 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + protected val serializableConf = + new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -89,7 +90,8 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + SparkHadoopUtil.get.getConfigurationFromJobContext(job). + set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -182,7 +184,9 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) + // scalastyle:off jobcontext this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + // scalastyle:on jobcontext } private def setupConf(): Unit = { @@ -222,8 +226,8 @@ private[sql] abstract class BaseWriterContainer( * A writer that writes all of the rows in a partition to a single file. 
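BaseWriterContainer above keeps its Hadoop Configuration inside a SerializableConfiguration because Hadoop's Configuration is not java.io.Serializable, while the writer container itself is shipped to executors. A minimal sketch of what such a wrapper amounts to (illustrative only, not Spark's actual org.apache.spark.util.SerializableConfiguration):

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import org.apache.hadoop.conf.Configuration

    class SerializableConf(@transient var value: Configuration) extends Serializable {
      // Configuration implements Writable, so delegate to its own wire format.
      private def writeObject(out: ObjectOutputStream): Unit = {
        out.defaultWriteObject()
        value.write(out)
      }
      private def readObject(in: ObjectInputStream): Unit = {
        in.defaultReadObject()
        value = new Configuration(false)
        value.readFields(in)
      }
    }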
*/ private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, + relation: HadoopFsRelation, + job: Job, isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { @@ -286,8 +290,8 @@ private[sql] class DefaultWriterContainer( * writer externally sorts the remaining rows and then writes out them out one file at a time. */ private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, + relation: HadoopFsRelation, + job: Job, partitionColumns: Seq[Attribute], dataColumns: Seq[Attribute], inputSchema: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 7a49157d9e72c..8ee0127c3bde8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -81,7 +81,7 @@ private[sql] class JSONRelation( private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) val paths = inputPaths.map(_.getPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 330ba907b2ef9..f65c7bbd6e29d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import scala.collection.Map @@ -89,7 +90,7 @@ private[sql] object JacksonGenerator { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v) => gen.writeString(v.toString) - case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) + case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) case (FloatType, v: Float) => gen.writeNumber(v) @@ -99,8 +100,12 @@ private[sql] object JacksonGenerator { case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v) => gen.writeString(v.toString) - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) + case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. 
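The valWriter changes here reflect that values reaching the generator are in Catalyst's internal representation: DateType values are days-since-epoch Ints, TimestampType values are microsecond Longs, and UDT values are already in their sqlType form. A small standalone sketch of the date and timestamp conversions, with java.time standing in for DateTimeUtils:

    import java.time.{Instant, LocalDate}

    // DateType is stored as days since 1970-01-01; TimestampType as microseconds since the epoch.
    def dateToString(days: Int): String = LocalDate.ofEpochDay(days.toLong).toString

    def timestampToString(micros: Long): String =
      Instant.ofEpochSecond(micros / 1000000L, (micros % 1000000L) * 1000L).toString

    // dateToString(0) == "1970-01-01"
    // timestampToString(0L) == "1970-01-01T00:00:00Z"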
+ case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) case (ArrayType(ty, _), v: ArrayData) => gen.writeStartArray() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index cd68bd667c5c4..ff4d8c04e8eaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -81,9 +81,37 @@ private[sql] object JacksonParser { case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => parser.getFloatValue + case (VALUE_STRING, FloatType) => + // Special case handling for NaN and Infinity. + val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toFloat + } else { + sys.error(s"Cannot parse $value as FloatType.") + } + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => parser.getDoubleValue + case (VALUE_STRING, DoubleType) => + // Special case handling for NaN and Infinity. + val value = parser.getText + val lowerCaseValue = value.toLowerCase() + if (lowerCaseValue.equals("nan") || + lowerCaseValue.equals("infinity") || + lowerCaseValue.equals("-infinity") || + lowerCaseValue.equals("inf") || + lowerCaseValue.equals("-inf")) { + value.toDouble + } else { + sys.error(s"Cannot parse $value as DoubleType.") + } + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => Decimal(parser.getDecimalValue, dt.precision, dt.scale) @@ -126,6 +154,9 @@ private[sql] object JacksonParser { case (_, udt: UserDefinedType[_]) => convertField(factory, parser, udt.sqlType) + + case (token, dataType) => + sys.error(s"Failed to parse a value for data type $dataType (current token: $token).") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index dc4ff06df6f22..b32897ee0f491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -72,7 +72,11 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // Called before `prepareForRead()` when initializing Parquet record reader. override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration + val conf = { + // scalastyle:off jobcontext + context.getConfiguration + // scalastyle:on jobcontext + } // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst // schema of this file from its metadata. @@ -117,14 +121,18 @@ private[parquet] object CatalystReadSupport { // Only clips array types with nested type as element type. clipParquetListType(parquetType.asGroupType(), t.elementType) - case t: MapType if !isPrimitiveCatalystType(t.valueType) => - // Only clips map types with nested type as value type. 
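The new VALUE_STRING cases in JacksonParser above accept JSON-quoted spellings of the IEEE special values ("NaN", "Infinity", "-Infinity", "Inf", "-Inf", in any case) for FloatType and DoubleType, and the new catch-all case turns any other token/type mismatch into an error instead of a MatchError. One way to express the accepted set explicitly (a sketch, not the parser's code; note that java.lang.Double.parseDouble accepts "Infinity" but not the short form "inf"):

    def parseSpecialDouble(value: String): Option[Double] = value.toLowerCase match {
      case "nan"                => Some(Double.NaN)
      case "infinity" | "inf"   => Some(Double.PositiveInfinity)
      case "-infinity" | "-inf" => Some(Double.NegativeInfinity)
      case _                    => None
    }

    // parseSpecialDouble("Inf") == Some(Double.PositiveInfinity); parseSpecialDouble("1.5") == None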
+ case t: MapType + if !isPrimitiveCatalystType(t.keyType) || + !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested key type or value type clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) case t: StructType => clipParquetGroup(parquetType.asGroupType(), t) case _ => + // UDTs and primitive types are not clipped. For UDTs, a clipped version might not be able + // to be mapped to desired user-space types. So UDTs shouldn't participate schema merging. parquetType } } @@ -204,14 +212,14 @@ private[parquet] object CatalystReadSupport { } /** - * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. The value type - * of the [[MapType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a - * [[StructType]]. Note that key type of any [[MapType]] is always a primitive type. + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or + * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], or + * a [[StructType]]. */ private def clipParquetMapType( parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { - // Precondition of this method, should only be called for maps with nested value types. - assert(!isPrimitiveCatalystType(valueType)) + // Precondition of this method, only handles maps with nested key types or value types. + assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType)) val repeatedGroup = parquetMap.getType(0).asGroupType() val parquetKeyType = repeatedGroup.getType(0) @@ -221,7 +229,7 @@ private[parquet] object CatalystReadSupport { Types .repeatedGroup() .as(repeatedGroup.getOriginalType) - .addField(parquetKeyType) + .addField(clipParquetType(parquetKeyType, keyType)) .addField(clipParquetType(parquetValueType, valueType)) .named(repeatedGroup.getName) @@ -257,7 +265,7 @@ private[parquet] object CatalystReadSupport { private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - val toParquet = new CatalystSchemaConverter(followParquetFormatSpec = true) + val toParquet = new CatalystSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index f17e794b76650..2ff2fda3610b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -123,6 +123,16 @@ private[parquet] class CatalystRowConverter( updater: ParentContainerUpdater) extends CatalystGroupConverter(updater) with Logging { + assert( + parquetType.getFieldCount == catalystType.length, + s"""Field counts of the Parquet schema and the Catalyst schema don't match: + | + |Parquet schema: + |$parquetType + |Catalyst schema: + |${catalystType.prettyJson} + """.stripMargin) + logDebug( s"""Building row converter for the following schema: | diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 
a21ab1dbb25d1..cd666f0020aff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -41,34 +41,31 @@ import org.apache.spark.sql.{AnalysisException, SQLConf} * @constructor * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL - * [[StructType]]. + * [[StructType]]. This argument only affects Parquet read path. * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` - * described in Parquet format spec. - * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when - * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and - * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and - * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is - * backwards-compatible with these settings. If this argument is set to `false`, we fallback - * to old style non-standard behaviors. + * described in Parquet format spec. This argument only affects Parquet read path. + * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4 + * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. + * When set to false, use standard format defined in parquet-format spec. This argument only + * affects Parquet write path. */ private[parquet] class CatalystSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - followParquetFormatSpec: Boolean = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get -) { + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - followParquetFormatSpec = conf.followParquetFormatSpec) + writeLegacyParquetFormat = conf.writeLegacyParquetFormat) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, - followParquetFormatSpec = conf.get(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key).toBoolean) + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. 
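Note the polarity flip in the rename above: the old followParquetFormatSpec (default false) and the new writeLegacyParquetFormat (default true) describe the same behavior from opposite directions, so the default behavior is unchanged. As a one-line sketch of the relationship:

    // "Follow the parquet-format spec" is simply the negation of "write the legacy layout".
    def followParquetFormatSpec(writeLegacyParquetFormat: Boolean): Boolean =
      !writeLegacyParquetFormat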
@@ -371,15 +368,15 @@ private[parquet] class CatalystSchemaConverter( case BinaryType => Types.primitive(BINARY, repetition).named(field.name) - // ===================================== - // Decimals (for Spark version <= 1.4.x) - // ===================================== + // ====================== + // Decimals (legacy mode) + // ====================== // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and // always store decimals in fixed-length byte arrays. To keep compatibility with these older // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated // by `DECIMAL`. - case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec => + case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) .as(DECIMAL) @@ -388,13 +385,13 @@ private[parquet] class CatalystSchemaConverter( .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - // ===================================== - // Decimals (follow Parquet format spec) - // ===================================== + // ======================== + // Decimals (standard mode) + // ======================== // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -404,7 +401,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -413,7 +410,7 @@ private[parquet] class CatalystSchemaConverter( .named(field.name) // Uses FIXED_LEN_BYTE_ARRAY for all other precisions - case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) .as(DECIMAL) @@ -422,17 +419,18 @@ private[parquet] class CatalystSchemaConverter( .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - // =================================================== - // ArrayType and MapType (for Spark versions <= 1.4.x) - // =================================================== + // =================================== + // ArrayType and MapType (legacy mode) + // =================================== - // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level - // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is - // covered by the backwards-compatibility rules implemented in `isElementType()`. - case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level + // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro + // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element + // field name "array" is borrowed from parquet-avro. 
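In standard mode the decimal cases above pick the narrowest Parquet physical type that can hold the unscaled value: INT32 for precision up to 9, INT64 up to 18, and FIXED_LEN_BYTE_ARRAY beyond that, whereas legacy mode always writes FIXED_LEN_BYTE_ARRAY. A compact sketch of that selection rule (the returned strings are just labels for the physical types; the cut-offs correspond to MAX_PRECISION_FOR_INT32 and MAX_PRECISION_FOR_INT64):

    // Up to 9 decimal digits fit in a signed 32-bit int, up to 18 in a signed 64-bit long.
    def parquetDecimalPhysicalType(precision: Int, writeLegacyParquetFormat: Boolean): String =
      if (writeLegacyParquetFormat) "FIXED_LEN_BYTE_ARRAY"
      else if (precision <= 9) "INT32"
      else if (precision <= 18) "INT64"
      else "FIXED_LEN_BYTE_ARRAY"

    // parquetDecimalPhysicalType(10, writeLegacyParquetFormat = false) == "INT64"

The array and map cases follow the same split: legacy mode keeps the Spark 1.4-style 3-level LIST group named "bag" with element field "array", while standard mode emits the parquet-format spec's "list"/"element" structure.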
+ case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => // group (LIST) { // optional group bag { - // repeated element; + // repeated array; // } // } ConversionPatterns.listType( @@ -441,13 +439,13 @@ private[parquet] class CatalystSchemaConverter( Types .buildGroup(REPEATED) // "array_element" is the name chosen by parquet-hive (1.7.0 and prior version) - .addField(convertField(StructField("array_element", elementType, nullable))) - .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + .addField(convertField(StructField("array", elementType, nullable))) + .named("bag")) // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is // covered by the backwards-compatibility rules implemented in `isElementType()`. - case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => // group (LIST) { // repeated element; // } @@ -459,7 +457,7 @@ private[parquet] class CatalystSchemaConverter( // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. - case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => // group (MAP) { // repeated group map (MAP_KEY_VALUE) { // required key; @@ -472,11 +470,11 @@ private[parquet] class CatalystSchemaConverter( convertField(StructField("key", keyType, nullable = false)), convertField(StructField("value", valueType, valueContainsNull))) - // ================================================== - // ArrayType and MapType (follow Parquet format spec) - // ================================================== + // ===================================== + // ArrayType and MapType (standard mode) + // ===================================== - case ArrayType(elementType, containsNull) if followParquetFormatSpec => + case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => // group (LIST) { // repeated group list { // element; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 2c6b914328b60..de1fd0166ac5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -53,7 +53,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) + val configuration = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(jobContext) + // scalastyle:on jobcontext + } val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 
c6bbc392cad4c..8a9c0e733a9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -211,7 +211,11 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) + val conf = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(job) + // scalastyle:on jobcontext + } // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -283,7 +287,7 @@ private[sql] class ParquetRelation( val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat // Parquet row group size. We will use this value as the value for // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value @@ -301,7 +305,7 @@ private[sql] class ParquetRelation( parquetFilterPushDown, assumeBinaryIsString, assumeInt96IsTimestamp, - followParquetFormatSpec) _ + writeLegacyParquetFormat) _ // Create the function to set input paths at the driver side. val setInputPaths = @@ -527,8 +531,8 @@ private[sql] object ParquetRelation extends Logging { parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean, - followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration + writeLegacyParquetFormat: Boolean)(job: Job): Unit = { + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. 
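To make the legacy/standard split in the CatalystSchemaConverter hunks above concrete, here is a self-contained sketch of how the decimal write path picks a Parquet physical type. The cut-offs 9 and 18 (standing in for MAX_PRECISION_FOR_INT32/INT64) and the byte-width helper are simplified assumptions, not the real implementation.

```scala
object DecimalEncodingSketch {
  sealed trait PhysicalType
  case object Int32 extends PhysicalType
  case object Int64 extends PhysicalType
  final case class FixedLenByteArray(length: Int) extends PhysicalType

  // Simplified stand-in for CatalystSchemaConverter.minBytesForPrecision: the smallest
  // byte width whose signed range covers every unscaled value with `precision` digits.
  def minBytesForPrecision(precision: Int): Int = {
    var numBytes = 1
    while (math.pow(2.0, 8.0 * numBytes - 1) < math.pow(10.0, precision)) numBytes += 1
    numBytes
  }

  // Legacy mode always uses FIXED_LEN_BYTE_ARRAY; standard mode prefers INT32/INT64
  // for small precisions (assumed cut-offs: 9 and 18) before falling back.
  def decimalPhysicalType(precision: Int, writeLegacyParquetFormat: Boolean): PhysicalType =
    if (writeLegacyParquetFormat) FixedLenByteArray(minBytesForPrecision(precision))
    else if (precision <= 9) Int32
    else if (precision <= 18) Int64
    else FixedLenByteArray(minBytesForPrecision(precision))
}
```

For example, `decimalPhysicalType(10, writeLegacyParquetFormat = false)` picks `Int64`, while the legacy path stores the same column as `FixedLenByteArray(5)`.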
@@ -557,7 +561,7 @@ private[sql] object ParquetRelation extends Logging { // Sets flags for Parquet schema conversion conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) - conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + conf.setBoolean(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, writeLegacyParquetFormat) overrideMinSplitSize(parquetBlockSize, conf) } @@ -572,7 +576,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, job.getConfiguration) + overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) } private[parquet] def readSchema( @@ -582,7 +586,7 @@ private[sql] object ParquetRelation extends Logging { val converter = new CatalystSchemaConverter( sqlContext.conf.isParquetBinaryAsString, sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.followParquetFormatSpec) + sqlContext.conf.writeLegacyParquetFormat) converter.convert(schema) } @@ -716,7 +720,7 @@ private[sql] object ParquetRelation extends Logging { filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val writeLegacyParquetFormat = sqlContext.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) // !! HACK ALERT !! @@ -756,7 +760,7 @@ private[sql] object ParquetRelation extends Logging { new CatalystSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) footers.map { footer => ParquetRelation.readSchemaFromFooter(footer, converter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 142301fe87cb6..b647bb6116afa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -123,7 +123,11 @@ private[parquet] object ParquetTypesConverter extends Logging { throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") } val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val conf = { + // scalastyle:off jobcontext + configuration.getOrElse(ContextUtil.getConfiguration(job)) + // scalastyle:on jobcontext + } val fs: FileSystem = origPath.getFileSystem(conf) if (fs == null) { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6c0196c21a0d1..bc255b27502b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,8 @@ 
import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.local.LocalNode +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -38,7 +39,7 @@ import org.apache.spark.{SparkConf, SparkEnv} * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[joins] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation { def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by @@ -111,7 +112,11 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. -private[joins] object HashedRelation { +private[execution] object HashedRelation { + + def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { + apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + } def apply( input: Iterator[InternalRow], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 6b7322671d6b4..906f20d2a7289 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -238,7 +238,7 @@ private[joins] class SortMergeJoinScanner( * Advances the streamed input iterator and buffers all rows from the buffered input that * have matching keys. * @return true if the streamed iterator returned a row, false otherwise. If this returns true, - * then [getStreamedRow and [[getBufferedMatches]] can be called to produce the outer + * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer * join results. 
*/ final def findNextOuterJoinRows(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index dea9e5e580a1e..c117dff9c8b1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.execution.joins +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * :: DeveloperApi :: * Performs an sort merge outer join of two child relations. - * - * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. */ @DeveloperApi case class SortMergeOuterJoin( @@ -52,6 +53,8 @@ case class SortMergeOuterJoin( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -62,6 +65,7 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is partitioned by the streamed input's join keys. case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -71,6 +75,8 @@ case class SortMergeOuterJoin( // For left and right outer joins, the output is ordered by the streamed input's join keys. 
case LeftOuter => requiredOrders(leftKeys) case RightOuter => requiredOrders(rightKeys) + // there are null rows in both streams, so there is no order + case FullOuter => Nil case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") } @@ -165,6 +171,26 @@ case class SortMergeOuterJoin( new RightOuterIterator( smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + numLeftRows, + rightIter = RowIterator.fromScala(rightIter), + numRightRows, + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + case x => throw new IllegalArgumentException( s"SortMergeOuterJoin should not take $x as the JoinType") @@ -173,98 +199,307 @@ case class SortMergeOuterJoin( } } - +/** + * An iterator for outputting rows in left outer join. + */ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var rightIdx: Int = 0 + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) +} + +/** + * An iterator for outputting rows in right outer join. + */ +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 0 or many rows, one for each matching row on the buffered side. + * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. 
+ * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) - private def advanceLeft(): Boolean = { - rightIdx = 0 + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. + */ + private def advanceStream(): Boolean = { + bufferIndex = 0 if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withLeft(smjScanner.getStreamedRow) + setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching right rows, so return nulls for the right row - joinedRow.withRight(rightNullRow) + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) } else { - // Find the next row from the right input that satisfied the bound condition - if (!advanceRightUntilBoundConditionSatisfied()) { - joinedRow.withRight(rightNullRow) + // Find the next row in the buffer that satisfied the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) } } true } else { - // Left input has been exhausted + // Stream has been exhausted false } } - private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. 
+ */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) - rightIdx += 1 + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 } foundMatch } override def advanceNext(): Boolean = { - val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() - if (r) numRows += 1 + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 r } override def getRow: InternalRow = resultProj(joinedRow) } -private class RightOuterIterator( - smjScanner: SortMergeJoinScanner, - leftNullRow: InternalRow, +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + numLeftRows: LongSQLMetric, + rightIter: RowIterator, + numRightRows: LongSQLMetric, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { + leftNullRow: InternalRow, + rightNullRow: InternalRow) { private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var leftIdx: Int = 0 - assert(smjScanner.getBufferedMatches.length == 0) + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ - private def advanceRight(): Boolean = { - leftIdx = 0 - if (smjScanner.findNextOuterJoinRows()) { - joinedRow.withRight(smjScanner.getStreamedRow) - if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching left rows, so return nulls for the left row - joinedRow.withLeft(leftNullRow) - } else { - // Find the next row from the left input that satisfied the bound condition - if (!advanceLeftUntilBoundConditionSatisfied()) { - joinedRow.withLeft(leftNullRow) - } - } + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + numLeftRows += 1 true } else { - // Right input has been exhausted + leftRow = null + leftRowKey = null false } } - private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { - var foundMatch: Boolean = false - while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { - foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) - leftIdx += 1 + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. 
+ */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + numRightRows += 1 + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. + * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators 
have been consumed + false } - foundMatch } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() override def advanceNext(): Boolean = { - val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() + val r = smjScanner.advanceNext() if (r) numRows += 1 r } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala new file mode 100644 index 0000000000000..b31c5a863832e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} + +case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToSafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToSafe = FromUnsafeProjection(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToSafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala new file mode 100644 index 0000000000000..de2f4e661ab44 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} + +case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var convertToUnsafe: Projection = _ + + override def open(): Unit = { + child.open() + convertToUnsafe = UnsafeProjection.create(child.schema) + } + + override def next(): Boolean = child.next() + + override def fetch(): InternalRow = convertToUnsafe(child.fetch()) + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index 81dd37c7da733..dd1113b6726cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { +case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var predicate: (InternalRow) => Boolean = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala new file mode 100644 index 0000000000000..e7b24e3fca2b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -0,0 +1,136 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. 
+ */ +case class HashJoinNode( + conf: SQLConf, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: LocalNode, + right: LocalNode) extends BinaryLocalNode(conf) { + + private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (left, leftKeys, right, rightKeys) + case BuildRight => (right, rightKeys, left, leftKeys) + } + + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ + private[this] var currentMatchPosition: Int = -1 + + private[this] var joinRow: JoinedRow = _ + private[this] var resultProjection: (InternalRow) => InternalRow = _ + + private[this] var hashed: HashedRelation = _ + private[this] var joinKeys: Projection = _ + + override def output: Seq[Attribute] = left.output ++ right.output + + private[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(schema)) + } + + private[this] def buildSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildNode.output) + } else { + newMutableProjection(buildKeys, buildNode.output)() + } + } + + private[this] def streamSideKeyGenerator: Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(streamedKeys, streamedNode.output) + } else { + newMutableProjection(streamedKeys, streamedNode.output)() + } + } + + override def open(): Unit = { + buildNode.open() + hashed = HashedRelation(buildNode, buildSideKeyGenerator) + streamedNode.open() + joinRow = new JoinedRow + resultProjection = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + joinKeys = streamSideKeyGenerator + } + + override def next(): Boolean = { + currentMatchPosition += 1 + if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { + fetchNextMatch() + } else { + true + } + } + + /** + * Populate `currentHashMatches` with build-side rows matching the next streamed row. + * @return whether matches are found such that subsequent calls to `fetch` are valid. + */ + private def fetchNextMatch(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamedNode.next()) { + currentStreamedRow = streamedNode.fetch() + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashed.get(key) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + + override def fetch(): InternalRow = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + resultProjection(ret) + } + + override def close(): Unit = { + left.close() + right.close() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 0000000000000..740d485f8d9e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) + extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala index fffc52abf6dd5..401b10a5ed307 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode { +case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { private[this] var count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 1c4469acbf264..e540ef8555eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Row +import scala.util.control.NonFatal + +import org.apache.spark.Logging +import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. 
* After consuming the iterator, close function must be called. */ -abstract class LocalNode extends TreeNode[LocalNode] { +abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { + + protected val codegenEnabled: Boolean = conf.codegenEnabled + + protected val unsafeEnabled: Boolean = conf.unsafeEnabled + + lazy val schema: StructType = StructType.fromAttributes(output) + + private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") def output: Seq[Attribute] @@ -57,10 +69,15 @@ abstract class LocalNode extends TreeNode[LocalNode] { */ def close(): Unit + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ - def collect(): Seq[Row] = { + final def collect(): Seq[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() @@ -73,17 +90,78 @@ abstract class LocalNode extends TreeNode[LocalNode] { } result } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + } -abstract class LeafLocalNode extends LocalNode { +abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { override def children: Seq[LocalNode] = Seq.empty } -abstract class UnaryLocalNode extends LocalNode { +abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { def child: LocalNode override def children: Seq[LocalNode] = Seq(child) } + +abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { + + def left: LocalNode + + def right: LocalNode + + override def children: Seq[LocalNode] = Seq(left, right) +} + +/** + * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. 
+ */ +private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { + private var nextRow: InternalRow = _ + + override def hasNext: Boolean = { + if (nextRow == null) { + val res = localNode.next() + if (res) { + nextRow = localNode.fetch() + } + res + } else { + true + } + } + + override def next(): InternalRow = { + if (hasNext) { + val res = nextRow + nextRow = null + res + } else { + throw new NoSuchElementException + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index 9b8a4fe493026..11529d6dd9b83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} -case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { +case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) + extends UnaryLocalNode(conf) { private[this] var project: UnsafeProjection = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 0000000000000..abf3df1c0c2af --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import java.util.Random + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + +/** + * Sample the dataset. + * + * @param conf the SQLConf + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. 
+ * @param seed the random seed + * @param child the LocalNode + */ +case class SampleNode( + conf: SQLConf, + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + val (sampler, _seed) = if (withReplacement) { + val random = new Random(seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result + // of DataFrame + random.nextLong()) + } else { + (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + } + sampler.setSeed(_seed) + iterator = sampler.sample(child.asIterator) + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index 242cb66e07b7f..b8467f6ae58e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute /** * An operator that scans some local data collection in the form of Scala Seq. */ -case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { +case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) + extends LeafLocalNode(conf) { private[this] var iterator: Iterator[InternalRow] = _ private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 0000000000000..53f1dcc65d8cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + conf: SQLConf, + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var projection: Option[Projection] = _ + private[this] var ord: InterpretedOrdering = _ + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. + child.close() + iterator = queue.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala index ba4aa7671aebd..0f2b8303e7372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -case class UnionNode(children: Seq[LocalNode]) extends LocalNode { +case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { override def output: Seq[Attribute] = children.head.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index f0b56c2eb7a53..a4dbd2e1978d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -74,16 +74,14 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") }} }} -
-          [removed XML block: inline "Detail:" section rendering {executionUIData.physicalPlanDescription}]
       val metrics = listener.getExecutionMetrics(executionId)

-      summary ++ planVisualization(metrics, executionUIData.physicalPlanGraph)
+      summary ++
+        planVisualization(metrics, executionUIData.physicalPlanGraph) ++
+        physicalPlanDescription(executionUIData.physicalPlanDescription)
     }.getOrElse {
       No information to display for Plan {executionId}
     }
@@ -124,4 +122,23 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") private def jobURL(jobId: Long): String = "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId)
+
+  private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
+    [added XML block: a collapsible "Details" toggle plus a hidden section rendering physicalPlanDescription]
    + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c4..60d9c509104d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -294,6 +294,33 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = Stddev(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** * Aggregate function: returns the sum of all values in the expression. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c393f6..7b50aad4ad498 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = 
sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -165,7 +168,7 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); + List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf5328c..5f9abd4999ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -61,7 +61,7 @@ public void tearDown() { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -90,6 +90,7 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions @@ -119,7 +120,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -161,7 +162,7 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. 
Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +181,8 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +195,16 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +212,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +221,14 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +236,7 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26a..3ab4db2a035d3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = 
stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map<String, Long> simpleMap = new HashMap<String, Long>(); + Map<String, Long> simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List<Row> arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map<List<Row>, Row> complexMap = new HashMap<List<Row>, Row>(); + Map<List<Row>, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca9be..4a78dca7fea66 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() { @Override - public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc39c..9e241f20987c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void
checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public void tearDown() { @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index af7590c3d3c17..356d4ff3fa837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.PhysicalRDD + import scala.concurrent.duration._ import scala.language.postfixOps @@ -34,7 +36,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -44,7 +46,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("withColumn doesn't invalidate cached dataframe") { @@ -69,41 +71,41 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + sqlContext.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != 
ctx.cacheManager.lookupCachedData(testData)) + assert(None != sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + sqlContext.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + sqlContext.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -111,103 +113,103 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("too big for memory") { val data = "*" * 1000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(sqlContext.table("bigData").count() === 200000L) + sqlContext.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + sqlContext.table("testData").cache() + assertCached(sqlContext.table("testData")) + sqlContext.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + sqlContext.table("testData").cache() + sqlContext.table("testData").count() + sqlContext.table("testData").unpersist(blocking = true) + assertCached(sqlContext.table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(sqlContext.table("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + sqlContext.uncacheTable("testData") + assert(!sqlContext.isCached("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + 
sqlContext.cacheTable("testData") + assertCached(sqlContext.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + sqlContext.uncacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + sqlContext.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + sqlContext.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -215,7 +217,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -224,14 +226,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have 
been unpersisted") } @@ -239,14 +241,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -254,7 +256,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -266,7 +268,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -274,7 +276,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -283,46 +285,48 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + sqlContext.table("t1") + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + sqlContext.cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(sqlContext.isCached("t1")) + assert(sqlContext.isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!sqlContext.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + sqlContext.clearCache() + assert(sqlContext.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("Clear CACHE") - 
assert(ctx.cacheManager.isEmpty) + assert(sqlContext.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() @@ -331,9 +335,23 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } + + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { + sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") + sqlContext.cacheTable("abc") + + val sparkPlan = sql( + """select a.key, b.key, c.key from + |abc a join abc b on a.key=b.key + |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan + + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) + assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 37738ec5b3c1d..4e988f074b113 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - ctx.createDataFrame(ctx.sparkContext.parallelize( + sqlContext.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -286,7 +286,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -307,7 +307,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -350,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -411,7 +411,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -556,7 +556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. 
- val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -567,7 +567,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("sparkPartitionId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -578,7 +578,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("InputFileName") { withTempPath { dir => - val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) .head.getString(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 72cf7aab0b977..f5ef9ffd7f4f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -66,12 +66,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { @@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 3c359dd840ab7..09f7b507670c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -28,19 +28,19 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { test("UDF on struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(struct($"a").as("s")).select(f($"s.a")).collect() } test("UDF on named_struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df 
= sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() } test("UDF on array") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index e5d7d63441a6b..094efbaeadcd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -24,7 +24,7 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 28bdd6f83b687..6524abcf5e97f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -29,7 +29,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample with replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -38,7 +38,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample without replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -47,7 +47,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -164,7 +164,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("Frequent Items 2") { - val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those 
that existed in the map when the map was full. This test should also fail // if anything like SPARK-9614 is observed once again @@ -182,7 +182,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = ctx.range(0, 100).select((col("id") % 3).as("key")) + val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 284fff184085a..c167999af580e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -22,6 +22,8 @@ import java.io.File import scala.language.postfixOps import scala.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -345,7 +347,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("replace column using withColumn") { - val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -434,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) @@ -506,7 +508,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -596,7 +598,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() @@ -619,14 +621,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -646,7 +648,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7324 dropDuplicates") { - val testData = sqlContext.sparkContext.parallelize( + val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ 
-869,7 +871,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -887,4 +889,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .select(struct($"b")) .collect() } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 77907e91363ec..7ae12a7895f7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -32,7 +32,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } } @@ -40,7 +40,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test struct type") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val struct = Row(1, 2L, 3.0F, 3.0) - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + val data = sparkContext.parallelize(Seq(Row(1, struct))) val schema = new StructType() .add("a", IntegerType) @@ -60,7 +60,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val innerStruct = Row(1, "abcd") val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) val schema = new StructType() .add("a", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 8d2f45d70308b..78a98798eff64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -52,7 +52,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { try { sqlContext.experimental.extraStrategies = TestStrategy :: Nil - val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( df.select("a"), Row("so fast")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 
f5c5046a8ed88..7a027e13089e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -31,7 +31,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -59,7 +59,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { @@ -138,7 +138,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( @@ -167,7 +167,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -442,7 +442,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index babf8835d2545..eab0fbb196eb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -32,33 +32,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - 
assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -66,7 +66,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), sql("SHOW TABLes")).foreach { + Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -77,9 +77,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + sqlContext.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3649c2a97b5ef..cada03e9ac6bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation -class QueryTest extends PlanTest { +abstract class QueryTest extends PlanTest { + + protected def sqlContext: SQLContext // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -56,18 +58,33 @@ class QueryTest extends PlanTest { * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
*/ - protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(df, expectedAnswer) match { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + val currentValue = sqlContext.conf.dataFrameEagerAnalysis + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val partiallyAnalzyedPlan = df.queryExecution.analyzed + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) + fail( + s""" + |Failed to analyze query: $ae + |$partiallyAnalzyedPlan + | + |${stackTraceToString(ae)} + |""".stripMargin) + } + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => } } - protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } - protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 77ccd6f775e50..3ba14d7602a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -57,7 +57,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 7699adadd9cc8..3d2bd236ceead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { @@ -27,58 +27,67 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(ctx.sparkContext) + val newContext = new SQLContext(sparkContext) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) - - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + // Set a conf first. + sqlContext.setConf(testKey, testVal) + // Clear the conf. + sqlContext.conf.clear() + // After clear, only overrideConfs used by unit test should be in the SQLConf. 
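The reworked checkAnswer above takes the DataFrame by name (df: => DataFrame), so an AnalysisException thrown while the DataFrame is being built is caught inside the helper and reported together with the partially analyzed plan instead of failing at the call site. A minimal sketch of the by-name idea, using assumed names that are not part of the patch:

    // Sketch only: `build` is evaluated lazily inside the helper, so construction
    // failures become a single descriptive test failure.
    def checkSafely[A](build: => A)(check: A => Unit): Unit = {
      val value =
        try build
        catch { case e: Exception => throw new AssertionError(s"failed to build value: $e") }
      check(value)
    }

    // The block passed here is not evaluated until checkSafely runs it.
    checkSafely(Seq(1, 2, 3).map(_ * 2))(xs => assert(xs.sum == 12))
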
+ assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + sqlContext.conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() + sqlContext.conf.clear() sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") + assert(sqlContext.getConf("some.property", "0") === "20") sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + assert(sqlContext.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + assert(sqlContext.getConf(key, "0") === vs) sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + assert(sqlContext.getConf(key, "0") === "") - ctx.conf.clear() + sqlContext.conf.clear() } test("deprecated property") { - ctx.conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.conf.numShufflePartitions === 10) + sqlContext.conf.clear() + val original = sqlContext.conf.numShufflePartitions + try{ + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(sqlContext.conf.numShufflePartitions === 10) + } finally { + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + } } test("invalid conf value") { - ctx.conf.clear() + sqlContext.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 007be12950774..dd88ae3700ab9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -24,7 +24,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { try { - SQLContext.setLastInstantiatedContext(ctx) + SQLContext.setLastInstantiatedContext(sqlContext) } finally { super.afterAll() } @@ -32,18 +32,18 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext { test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { 
SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + val sqlContext = new SQLContext(sparkContext) + assert(SQLContext.getOrCreate(sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e172b2c264cb..962b100b532c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -147,14 +147,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -196,7 +196,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( + sqlContext.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -215,7 +215,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-6201 IN type conversion") { sqlContext.read.json( - sqlContext.sparkContext.parallelize( + sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. 
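As a quick cross-check of the stddev expectations above: stddev and stddev_samp divide the summed squared deviations by n - 1, while stddev_pop divides by n. Assuming testData2 holds the pairs (1,1), (1,2), (2,1), (2,2), (3,1), (3,2), as the expected constants imply, a minimal sketch recomputes the column-a values:

    // Illustrative only: recompute the expected constants for column `a` from the assumed data.
    val a = Seq(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val mean = a.sum / a.size                          // 2.0
    val ss = a.map(x => (x - mean) * (x - mean)).sum   // 4.0
    val sampleStddev = math.sqrt(ss / (a.size - 1))    // sqrt(4/5.0): stddev, stddev_samp
    val populationStddev = math.sqrt(ss / a.size)      // sqrt(4/6.0): stddev_pop

Within each group of a, column b takes the two values 1 and 2, giving a sample variance of 0.5 and a population variance of 0.25, which is where the per-group sqrt(0.5) and sqrt(0.25) rows come from.
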
testCodeGen( """ @@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), @@ -991,21 +1025,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. 
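The SET expectations change just below because the shared test SQLContext now applies a set of override configurations up front, so a bare SET no longer returns an empty result; the suite instead verifies that every returned key/value pair matches TestSQLContext.overrideConfs and then treats that listing as the baseline for later assertions. A rough sketch of that containment check, with assumed simplified types that are not part of the patch:

    // Sketch only: each (key, value) pair reported by a SET-style listing must appear,
    // with the same value, in the map of expected overrides.
    def allOverridesListed(listing: Seq[(String, String)], overrides: Map[String, String]): Boolean =
      listing.forall { case (k, v) => overrides.get(k) == Some(v) }
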
- assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size) + sql("SET").collect().foreach { row => + val key = row.getString(0) + val value = row.getString(1) + assert( + TestSQLContext.overrideConfs.contains(key), + s"$key should exist in SQLConf.") + assert( + TestSQLContext.overrideConfs(key) === value, + s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.") + } + val overrideConfs = sql("SET").collect() // "set key=val" sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(testKey, testVal) + overrideConfs ++ Seq(Row(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), - Seq( - Row(testKey, testVal), - Row(testKey + testKey, testVal + testVal)) + overrideConfs ++ Seq(Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) ) // "set key" @@ -1342,7 +1385,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") @@ -1385,13 +1428,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } @@ -1412,10 +1455,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1424,7 +1467,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1432,14 +1475,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-4699 case sensitivity SQL query") { sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") 
:: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1452,14 +1495,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1490,6 +1533,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ @@ -1533,7 +1586,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1600,8 +1653,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen updates peak execution memory") { withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { testCodeGen( "SELECT key, count(value) FROM testData GROUP BY key", (1 to 100).map(i => Row(i, 1))) @@ -1660,8 +1712,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { sortTest() } } @@ -1669,7 +1720,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-9511: error with table starting with number") { withTempTable("1one") { - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") .registerTempTable("1one") checkAnswer(sql("select count(num) from 1one"), Row(10)) @@ -1680,7 +1731,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withTempPath { dir => val path = dir.getCanonicalPath val df = 
- sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df .write .format("parquet") @@ -1712,9 +1763,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10130 type coercion for IF should have children resolved first") { - val df = Seq((1, 1), (-1, 1)).toDF("key", "value") - df.registerTempTable("src") - checkAnswer( - sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 45d0ee4a8e749..ddab918629645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sqlContext.sparkContext) + val _sqlContext = new SQLContext(sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index b91438baea06f..e12e6bea30260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -268,9 +268,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row(3, 4)) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) + df.selectExpr("length(c)") // int type of the argument is unacceptable } } @@ -284,63 +282,46 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select(format_number($"f", 4)), + val df = sqlContext.range(1) + + checkAnswer( + df.select(format_number(lit(5L), 4)), Row("5.0000")) checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer Row("1.0000")) checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer Row("2.0000")) 
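For readers less familiar with format_number: the rewritten test around this point drives it through Column literals rather than a wide tuple DataFrame, but the coercion rules it checks are unchanged. A rough interactive sketch (assuming a SQLContext in scope as sqlContext and the standard org.apache.spark.sql.functions._ import):

    import org.apache.spark.sql.functions.{format_number, lit}

    // format_number(col, d) rounds to d decimal places; integral and fractional
    // inputs are accepted, while a string first argument fails at analysis time.
    val df = sqlContext.range(1)                        // one dummy row to select against

    df.select(format_number(lit(5L), 4)).show()         // 5.0000
    df.select(format_number(lit(3.1322f), 4)).show()    // 3.1322  (float widened to double)
    df.select(format_number(lit(6.48173), 4)).show()    // 6.4817  (rounded to 4 places)

    // df.select(format_number(lit("aa"), 4))           // would throw AnalysisException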
checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double Row("3.1322")) checkAnswer( - df.selectExpr("format_number(e, e)"), // not convert anything + df.select(format_number(lit(4), 4)), // not convert anything Row("4.0000")) checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything + df.select(format_number(lit(5L), 4)), // not convert anything Row("5.0000")) checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything + df.select(format_number(lit(6.48173), 4)), // not convert anything Row("6.4817")) checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything + df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything Row("7.1284")) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) + df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable } intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) + df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable } // for testing the mutable state of the expression in code gen. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index eb275af101e2f..e0435a0dba6ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - ctx.dropTempTable("tmp_table") + sqlContext.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => - val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - ctx.dropTempTable("test_table") + sqlContext.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -87,24 +87,24 @@ class UDFSuite extends QueryTest with 
SharedSQLContext { } test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) + sqlContext.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) + sqlContext.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = ctx.sparkContext.parallelize( + val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("integerData") @@ -114,7 +114,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -133,7 +133,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -150,10 +150,10 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) - ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) - ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) - ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) + sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) + sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -172,7 +172,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -181,13 +181,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - ctx.udf.register("intExpected", (x: Int) => x) + sqlContext.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b6d279ae47268..46d87843dfa4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -90,7 +90,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -148,8 +148,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { StructField("vec", new MyDenseVectorUDT, false) )) - val stringRDD = ctx.sparkContext.parallelize(data) - val jsonRDD = ctx.read.schema(schema).json(stringRDD) + val stringRDD = sparkContext.parallelize(data) + val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -157,4 +157,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { Nil ) } + + test("SPARK-10472 UserDefinedType.typeName") { + assert(IntegerType.typeName === "integer") + assert(new MyDenseVectorUDT().typeName === "mydensevector") + assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 952637c5f9cb8..cd3644eb9c099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -39,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + sqlContext.cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + sqlContext.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to 
be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -69,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + sqlContext.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -81,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + sqlContext.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -96,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - ctx.cacheTable("timestamps") + sqlContext.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -108,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("withEmptyParts") + sqlContext.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -177,18 +177,39 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + sqlContext.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + sqlContext.table("InMemoryCache_different_data_types").collect()) + sqlContext.dropTempTable("InMemoryCache_different_data_types") + } + + test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { + val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val cached = df.cache() + // count triggers the caching action. It should not throw. + cached.count() + + // Make sure, the DataFrame is indeed cached. + assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + + // Check result. 
+ checkAnswer( + cached, + sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + ) + + // Drop the cache. + cached.unpersist() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index ab2644eb4581d..6b7401464f46f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -25,32 +25,32 @@ import org.apache.spark.sql.test.SQLTestData._ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - ctx.cacheTable("pruningData") + sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - ctx.uncacheTable("pruningData") + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 8998f5111124c..911d12e93e503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fad93b014c237..cafa1d5154788 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow @@ -31,14 +30,14 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class PlannerSuite extends SparkFunSuite with SharedSQLContext { +class PlannerSuite extends SharedSQLContext { import testImplicits._ setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -53,8 +52,8 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("unions are collapsed") { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,33 +80,30 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) - val fields = fieldTypes.zipWithIndex.map { - case (dataType, index) => StructField(s"c${index}", dataType, true) - } :+ StructField("key", IntegerType, true) - val schema = StructType(fields) - val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = ctx.sparkContext.parallelize(row :: Nil) - ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") - - val planned = sql( - """ - |SELECT l.a, l.b - |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan - - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - - ctx.dropTempTable("testLimit") + def checkPlan(fieldTypes: Seq[DataType]): Unit = { + withTempTable("testLimit") { + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + } } - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - val simpleTypes = NullType :: BooleanType :: @@ -124,7 +120,9 @@ class 
PlannerSuite extends SparkFunSuite with SharedSQLContext { StringType :: BinaryType :: Nil - checkPlan(simpleTypes, newThreshold = 16434) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") { + checkPlan(simpleTypes) + } val complexTypes = ArrayType(DoubleType, true) :: @@ -136,36 +134,37 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false))) :: Nil - checkPlan(complexTypes, newThreshold = 901617) - - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") { + checkPlan(complexTypes) + } } test("InMemoryRelation statistics propagation") { - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) - - testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { + withTempTable("tiny") { + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") - val a = testData.as("a") - val b = ctx.table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + val a = testData.as("a") + val b = sqlContext.table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + sqlContext.clearCache() + } + } } test("efficient limit -> project -> sort") { { val query = testData.select('key, 'value).sort('key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) } @@ -175,7 +174,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { // into it. 
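Most of the churn in PlannerSuite above is replacing manual setConf/restore bookkeeping with the scoped helpers from SQLTestUtils (withSQLConf, withTempTable), which restore state in a finally block even when the body throws. A simplified stand-in for the conf helper, to show the shape of what those calls do (the names here are illustrative, not the real implementation):

    import org.apache.spark.sql.SQLContext

    // Simplified sketch of a withSQLConf-style helper: apply overrides, run the body,
    // then put back whatever values were there before.
    def withSQLConfSketch(sqlContext: SQLContext)(pairs: (String, String)*)(body: => Unit): Unit = {
      val previous = pairs.map { case (k, _) => k -> sqlContext.getAllConfs.get(k) }
      pairs.foreach { case (k, v) => sqlContext.setConf(k, v) }
      try body finally {
        previous.foreach {
          case (k, Some(old)) => sqlContext.setConf(k, old)   // restore the prior value
          case (_, None)      => ()                           // the real helper also unsets newly added keys
        }
      }
    }

    // e.g. withSQLConfSketch(sqlContext)("spark.sql.autoBroadcastJoinThreshold" -> "16434") { /* body */ }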
val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index ef6ad59b71fb3..4492e37ad01ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - ctx.createDataFrame(input), + sqlContext.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(ctx) + SparkPlan.currentContext.set(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 8fa77b0fcb7b7..3073d492e613b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext class SortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 5ab8f44faebf6..de45ae4635fb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -31,14 +31,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def _sqlContext: SQLContext - - /** - * Creates a DataFrame from a local Seq of Product. - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - _sqlContext.implicits.localSeqToDataFrameHolder(data) - } + protected def sqlContext: SQLContext /** * Runs the plan and makes sure the answer matches the expected result. @@ -98,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -122,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -149,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, _sqlContext) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -170,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -210,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val 
errorMessage = @@ -238,10 +231,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = _sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 3158458edb831..7a0f0dfd2b7f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -29,15 +29,16 @@ import org.apache.spark.sql.types._ * A test suite that generates randomized data to test the [[TungstenSort]] operator. */ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder override def beforeAll(): Unit = { super.beforeAll() - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { try { - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) } finally { super.afterAll() } @@ -64,8 +65,7 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { } test("sorting updates peak execution memory") { - val sc = ctx.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), @@ -83,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = ctx.createDataFrame( - ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index bd02c73a26ace..0113d052e338d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.execution -import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} -import org.apache.spark.SparkFunSuite +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter 
+import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark._ /** @@ -40,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { - val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val converter = unsafeRowConverter(schema) + converter(row) + } + + private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = { val converter = UnsafeProjection.create(schema) - converter.apply(internalRow) + (row: Row) => { + converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) + } } test("toUnsafeRow() test helper method") { @@ -87,4 +97,50 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(!deserializerIter.hasNext) assert(input.closed) } + + test("SPARK-10466: external sorter spilling with unsafe row serializer") { + var sc: SparkContext = null + var outputFile: File = null + val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten + Utils.tryWithSafeFinally { + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.shuffle.memoryFraction", "0.0001") + + sc = new SparkContext("local", "test", conf) + outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 1000).iterator.map { i => + (i, converter(Row(i))) + } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + partitioner = Some(new HashPartitioner(10)), + serializer = Some(new UnsafeRowSerializer(numFields = 1))) + + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) + + // Merging spilled files should not throw assertion error + val taskContext = + new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) + taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + } { + // Clean up + if (sc != null) { + sc.stop() + } + + // restore the spark env + SparkEnv.set(oldEnv) + + if (outputFile != null) { + outputFile.delete() + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 5fdb82b067728..afda0d29f6d91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -37,7 +37,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { () => new InterpretedMutableProjection(expr, schema) } - val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") iter = new 
TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1174b27732f22..6a18cc6d27138 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -215,7 +215,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +234,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -262,7 +262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -361,7 +361,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -377,7 +377,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -449,7 +449,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. 
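A note on the SPARK-10466 test added a little earlier: the point of the unusual SparkConf values is to make ExternalSorter spill almost immediately on roughly a thousand tiny records, so that merging the spill files exercises the UnsafeRowSerializer path that previously hit an assertion. Roughly (values copied from the test; they are test-only knobs, not recommended settings):

    import org.apache.spark.{SparkConf, SparkContext}

    // Shrink the shuffle memory knobs so even a tiny insertAll() spills to disk.
    val conf = new SparkConf()
      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")  // start checking memory after ~1 KB
      .set("spark.shuffle.sort.bypassMergeThreshold", "0")        // never take the bypass-merge path
      .set("spark.shuffle.memoryFraction", "0.0001")              // give the shuffle almost no memory
    val sc = new SparkContext("local", "spill-sketch", conf)
    try {
      // build an ExternalSorter[Int, UnsafeRow, UnsafeRow], insertAll the rows,
      // and assert sorter.numSpills > 0, as the test above does
    } finally {
      sc.stop()
    }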
@@ -502,7 +502,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -526,7 +526,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -554,7 +554,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -573,9 +573,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +590,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path) + sqlContext.read.schema(schema).json(path) .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) @@ -603,7 +603,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = sqlContext.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -672,7 +672,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = sqlContext.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -689,7 +689,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -710,7 +710,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -738,7 +738,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with 
TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -764,7 +764,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -782,7 +782,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -805,7 +805,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -823,64 +823,63 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = ctx.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val jsonDF = sqlContext.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
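The corrupt-records test above keeps the same behaviour under the new withSQLConf/withTempTable scaffolding: with spark.sql.columnNameOfCorruptRecord pointed at "_unparsed", lines that fail to parse are preserved verbatim in that column and contribute nulls to the data columns. A rough standalone version (assuming sqlContext and sparkContext are in scope, as in these suites):

    // Route unparseable JSON lines into an "_unparsed" column instead of dropping them.
    sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_unparsed")
    val lines = sparkContext.parallelize(Seq(
      """{"a": "ok"}""",     // parses normally
      """{"a": broken"""     // malformed: kept whole in _unparsed, a is null
    ))
    val df = sqlContext.read.json(lines)
    df.where(df("_unparsed").isNotNull).show()                                     // only the malformed line
    sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_corrupt_record")   // back to the default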
+ checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + } + } } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -926,7 +925,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -949,7 +948,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -957,8 +956,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val primTable = sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -970,8 +969,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = sqlContext.read.json(complexFieldAndType1) + val compTable = sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. 
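+ // (compTable is complexJsonDF round-tripped through toJSON and read.json, so the queries below
+ // also verify that nested arrays, structs and maps survive the JSON round trip.)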
checkAnswer( @@ -1039,25 +1038,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1078,18 +1077,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1105,24 +1104,21 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val schemaWithSimpleMap = StructType( - StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try { - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } finally { - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempDir { dir => + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) + // order of MapType is not defined + assert(sqlContext.read.parquet(path).count() == 5) + + val df2 = sqlContext.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df2.collect()) + } } } @@ -1142,19 +1138,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d1 = new File(root, "d1=1") // root/dt=1/col1=abc val p1_col1 = makePartition( - ctx.sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abc") // root/dt=1/col1=abd val p2 = makePartition( - 
ctx.sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abd") - ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2864181cf91d5..713d1da1cb515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext private[json] trait TestJsonData { - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: 
Nil) def complexFieldAndType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,7 +198,7 @@ private[json] trait TestJsonData { """]""" :: Nil) - lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) - def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 91f3ce4d34c8b..0835bd123049b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -39,12 +39,13 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) - val fs = fsPath.getFileSystem(configuration) + val fs = fsPath.getFileSystem(hadoopConfiguration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { override def accept(path: Path): Boolean = pathFilter(path) }).toSeq.asJava - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + val footers = + ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true) footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 08d2b9dee99b0..cd552e83372f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -101,7 +101,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -119,7 +119,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() @@ -207,7 +207,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("compression codec") { def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala + .readMetaData(new Path(path), Some(hadoopConfiguration)).getBlocks.asScala .flatMap(_.getColumns.asScala) .map(_.getCodec.name()) .distinct @@ -277,14 +277,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("write metadata") { withTempPath { file => val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) + val fs = FileSystem.getLocal(hadoopConfiguration) val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + ParquetTypesConverter.writeMetaData(attributes, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val metaData = ParquetTypesConverter.readMetaData(path, Some(hadoopConfiguration)) val actualSchema = metaData.getFileMetaData.getSchema val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) @@ -355,7 +355,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, + sparkContext.hadoopConfiguration, path, Collections.singletonList( new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) @@ -370,12 +370,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. 
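+ // clonedConf snapshots the shared hadoopConfiguration so that the finally block can copy the
+ // original entries back after the committer class is overridden below.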
try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -383,23 +383,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -407,25 +407,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) @@ -436,8 +436,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(message === "Intentional exception for testing purposes") } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } @@ -455,11 +455,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-7837 Do not close output writer twice when commitTask() fails") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Using a output committer that always fail when committing a task, so that both // `commitTask()` and `abortTask()` are invoked. 
- configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) @@ -483,8 +483,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index ed8bafb10c60b..7bac8609e1b91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -517,7 +517,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index a379523d67f80..1c1cfa34ad04b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,6 +22,9 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -30,6 +33,7 @@ import org.apache.spark.util.Utils * A test suite that tests various Parquet queries. 
*/ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -40,22 +44,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -118,9 +122,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = ctx.read.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -129,12 +133,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
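+ // _metadata and _common_metadata are the Parquet summary files written alongside the data
+ // files; removing the ones under foo=1 leaves only the part-file footers to describe that
+ // partition's schema.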
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -153,9 +157,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -171,19 +175,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -193,7 +197,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) - val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) val df = sqlContext.createDataFrame(rowRDD, schema) df.write.parquet(basePath) @@ -203,9 +207,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("SPARK-10005 Schema merging for nested struct") { - val sqlContext = _sqlContext - import sqlContext.implicits._ - withTempPath { dir => val path = dir.getCanonicalPath @@ -230,54 +231,168 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - test("SPARK-10301 Clipping nested structs in requested schema") { + test("SPARK-10301 requested schema clipping - same schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + 
sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + } + + // This test case is ignored because of parquet-mr bug PARQUET-370 + ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(null, null))) + } + } + + test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L, null, null))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, null, null, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - df.write.mode("append").parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", new StructType().add("a", LongType, nullable = true), nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Row(Row(0))) + Row(Row(0L, 1L))) } withTempPath { dir => val path = dir.getCanonicalPath - - val df1 = sqlContext + val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - val df2 = sqlContext - .range(1, 2) - .selectExpr("NAMED_STRUCT('b', id, 'c', id) AS s") + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = 
sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - df1.write.parquet(path) - df2.write.mode(SaveMode.Append).parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", - new StructType() - .add("a", LongType, nullable = true) - .add("c", LongType, nullable = true), - nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Seq( - Row(Row(0, null)), - Row(Row(null, 1)))) + Row(Row(1L, 2L, null))) } + } + test("SPARK-10301 requested schema clipping - deeply nested struct") { withTempPath { dir => val path = dir.getCanonicalPath @@ -306,4 +421,132 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext Row(Row(Seq(Row(0, null))))) } } + + test("SPARK-10301 requested schema clipping - out of order") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, 1, null)), + Row(Row(null, 2, 4)))) + } + } + + test("SPARK-10301 requested schema clipping - schema merging") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df1.write.mode(SaveMode.Append).parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + checkAnswer( + sqlContext + .read + .option("mergeSchema", "true") + .parquet(path) + .selectExpr("s.a", "s.b", "s.c"), + Seq( + Row(0, null, 2), + Row(1, 2, 3))) + } + } + + test("SPARK-10301 requested schema clipping - UDT") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr( + """NAMED_STRUCT( + | 'f0', CAST(id AS STRING), + | 'f1', NAMED_STRUCT( + | 'a', CAST(id + 1 AS INT), + | 'b', CAST(id + 2 AS LONG), + | 'c', CAST(id + 3.5 AS DOUBLE) + | ) + |) AS s + """.stripMargin) + .coalesce(1) + + df.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("f1", new NestedStructUDT, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(NestedStruct(1, 2L, 3.5D)))) + } + } +} + +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + case class NestedStruct(a: Integer, b: Long, c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = + new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def 
serialize(obj: Any): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + obj match { + case n: NestedStruct => + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = { + datum match { + case row: InternalRow => + NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 28c59a4abdd76..f17fb36f25fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -35,32 +34,29 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( testName: String, messageType: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testSchema( testName, StructType.fromAttributes(ScalaReflection.attributesFor[T]), messageType, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } protected def testParquetToCatalyst( testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -78,14 +74,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -99,10 +94,9 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: 
Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testCatalystToParquet( testName, @@ -110,8 +104,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) testParquetToCatalyst( testName, @@ -119,8 +112,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } } @@ -137,7 +129,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _6; |} """.stripMargin, - binaryAsString = false) + binaryAsString = false, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", @@ -149,7 +143,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | required int64 _4 (INT_64); | optional int32 _5 (DATE); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "string", @@ -158,7 +155,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); |} """.stripMargin, - binaryAsString = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "binary enum as string", @@ -166,7 +165,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional binary _1 (ENUM); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", @@ -176,7 +178,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - standard", @@ -189,7 +194,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - non-standard", @@ -197,11 +204,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - standard", @@ -214,7 +224,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - standard", @@ -228,7 +240,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, 
+ writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - non-standard", @@ -241,7 +255,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", @@ -253,7 +270,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - non-standard", @@ -266,7 +285,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array_element { + | optional group array { | required int32 _1; | required double _2; | } @@ -276,7 +295,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - standard", @@ -300,7 +322,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( "optional types", @@ -315,36 +339,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) - - // Parquet files generated by parquet-thrift are already handled by the schema converter, but - // let's leave this test here until both read path and write path are all updated. 
- ignore("thrift generated parquet schema") { - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchemaInference[( - Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", - """ - |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } - | } - | } - |} - """.stripMargin, - isThriftDerived = true) - } + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) } class ParquetSchemaSuite extends ParquetSchemaTest { @@ -470,7 +467,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with nullable element type - 2", @@ -486,7 +486,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -499,7 +502,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 2", @@ -512,7 +518,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 3", @@ -523,7 +532,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 element; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 4", @@ -544,7 +556,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", @@ -563,7 +578,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", @@ -582,7 +600,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 7 - " + @@ -592,7 +613,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | repeated int32 f1; |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp 
= true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 8 - " + @@ -612,7 +636,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | required int32 c2; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST @@ -633,7 +660,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", @@ -645,11 +674,14 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -666,7 +698,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", @@ -680,7 +714,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Parquet Map to Catalyst MapType @@ -701,7 +738,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 2", @@ -718,7 +758,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", @@ -735,7 +778,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -752,7 +798,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 2", @@ -769,7 +818,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", @@ -786,7 +838,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + 
""".stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Catalyst MapType to Parquet Map @@ -808,7 +863,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", @@ -825,7 +882,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -843,7 +903,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", @@ -860,7 +922,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ================================= // Tests for conversion for decimals @@ -873,7 +938,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(1, 0)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(8, 3) - standard", @@ -882,7 +949,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(8, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(9, 3) - standard", @@ -891,7 +960,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(9, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(18, 3) - standard", @@ -900,7 +971,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int64 f1 (DECIMAL(18, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(19, 3) - standard", @@ -909,7 +982,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(1, 0) - prior to 1.4.x", @@ -917,7 +992,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(8, 3) - prior to 1.4.x", @@ -925,7 +1003,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( 
"DECIMAL(9, 3) - prior to 1.4.x", @@ -933,7 +1014,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(18, 3) - prior to 1.4.x", @@ -941,7 +1025,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) private def testSchemaClipping( testName: String, @@ -1012,12 +1099,17 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11Type = new StructType().add("f011", DoubleType, nullable = true) - val f01Type = ArrayType(StringType, containsNull = false) + val f00Type = ArrayType(StringType, containsNull = false) + val f01Type = ArrayType( + new StructType() + .add("f011", DoubleType, nullable = true), + containsNull = false) + val f0Type = new StructType() - .add("f00", f01Type, nullable = false) - .add("f01", f11Type, nullable = false) + .add("f00", f00Type, nullable = false) + .add("f01", f01Type, nullable = false) val f1Type = ArrayType(IntegerType, containsNull = true) + new StructType() .add("f0", f0Type, nullable = false) .add("f1", f1Type, nullable = true) @@ -1046,7 +1138,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1061,13 +1153,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1075,7 +1167,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1095,7 +1187,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1110,13 +1202,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1124,7 +1216,7 @@ class ParquetSchemaSuite extends 
ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1236,6 +1328,63 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin) + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + testSchemaClipping( "empty requested schema", @@ -1251,4 +1400,160 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), expectedSchema = "message root {}") + + testSchemaClipping( + "disjoint field sets", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = + new StructType() + .add( + "f0", + new StructType() + .add("f02", FloatType, nullable = true) + .add("f03", DoubleType, nullable = true), + nullable = true), + + expectedSchema = + """message root { + | required group f0 { + | optional float f02; + | optional double f03; + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new 
StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int32 value_f0; + | required int64 value_f1; + | } + | required int32 value; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int64 value_f1; + | required double value_f2; + | } + | required int32 value; + | } + | } + |} + """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 5dbc7d1630f27..442fafb12f200 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ private[sql] trait ParquetTest extends SQLTestUtils { - protected def _sqlContext: SQLContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -43,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -55,7 +54,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -67,14 +66,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - _sqlContext.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 53a0e53fd7719..dcbfdca71acb6 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} * without serializing the hashed relation, which does not happen in local mode. */ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - private var sc: SparkContext = null - private var sqlContext: SQLContext = null + protected var sqlContext: SQLContext = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -44,15 +43,14 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { val conf = new SparkConf() .setMaster("local-cluster[2,1,1024]") .setAppName("testing") - sc = new SparkContext(conf) + val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - sc.stop() - sc = null + sqlContext.sparkContext.stop() sqlContext = null } @@ -60,7 +58,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { * Test whether the specified broadcast join updates the peak execution memory accumulator. */ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 4c9187a9a7106..e5fd9e277fc61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) 
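For reference, the ParquetSchemaSuite hunks earlier in this diff exercise two physical Parquet layouts for list-typed columns. The sketch below only restates the message-schema strings already embedded in those tests, for an ArrayType(StringType, containsNull = false) field; the field name f is a placeholder. As the new flag's name suggests, writeLegacyParquetFormat = true corresponds to the older 2-level style and false to the standard 3-level parquet-format style.

// Legacy (pre-1.4.x) 2-level layout for a non-nullable string array:
val legacyListLayout =
  """message root {
    |  optional group f (LIST) {
    |    repeated binary array (UTF8);
    |  }
    |}
  """.stripMargin

// Standard parquet-format 3-level layout for the same Catalyst type:
val standardListLayout =
  """message root {
    |  optional group f (LIST) {
    |    repeated group list {
    |      required binary element (UTF8);
    |    }
    |  }
    |}
  """.stripMargin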
val unsafeData = data.map(toUnsafe(_).copy()).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index cc649b9bd4c45..4174ee055021d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder - private lazy val myUpperCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myUpperCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), Row(3, "C"), @@ -39,8 +40,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myLowerCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), Row(3, "c"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a1a617d7b7398..09e0237a7cc50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -40,8 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, -1.0), @@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } if (joinType != FullOuter) { test(s"$testName using 
BroadcastHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } + } - test(s"$testName using SortMergeOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index baa86e320d986..3afd762942bcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), @@ -40,8 +40,8 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(2, 3.0), Row(2, 3.0), Row(3, 2.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index 07209f3779248..a12670e347c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (testData.col("key") % 2) === 0 checkAnswer( testData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), testData.filter(condition).collect() ) } @@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { val condition = (emptyTestData.col("key") % 2) === 0 checkAnswer( 
emptyTestData, - node => FilterNode(condition.expr, node), + node => FilterNode(conf, condition.expr, node), emptyTestData.filter(condition).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 0000000000000..43b6f06aead88 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,130 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.execution.joins + +class HashJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + + def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { + test(s"$suiteName: inner join with one match per row") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(upperCaseData.col("N").expr), + Seq(lowerCaseData.col("n").expr), + joins.BuildLeft, + node1, + node2) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N").collect() + ) + } + } + + test(s"$suiteName: inner join with multiple matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + x.join(y).where($"x.a" === $"y.a").collect() + ) + } + } + + test(s"$suiteName: inner join, no matches") { + withSQLConf(confPairs: _*) { + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") + checkAnswer2( + x, + y, + wrapForUnsafe( + (node1, node2) => HashJoinNode( + conf, + Seq(x.col("a").expr), + Seq(y.col("a").expr), + joins.BuildLeft, + node1, + node2) + ), + Nil + ) + } + } + + test(s"$suiteName: big inner join, 4 matches per row") { + withSQLConf(confPairs: _*) { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") + + checkAnswer2( + bigDataX, + bigDataY, + wrapForUnsafe( + (node1, node2) => + HashJoinNode( + conf, + Seq(bigDataX.col("key").expr), + Seq(bigDataY.col("key").expr), + 
joins.BuildLeft, + node1, + node2) + ), + bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) + } + } + } + + joinSuite( + "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 0000000000000..7deaa375fcfc2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +class IntersectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("basic") { + val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer2( + input1, + input2, + (node1, node2) => IntersectNode(conf, node1, node2), + input1.intersect(input2).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 523c02f4a6014..3b183902007e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("basic") { checkAnswer( testData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), testData.limit(10).collect() ) } @@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { test("empty") { checkAnswer( emptyTestData, - node => LimitNode(10, node), + node => LimitNode(conf, 10, node), emptyTestData.limit(10).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 0000000000000..b89fa46f8b3b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class LocalNodeSuite extends SparkFunSuite { + private val data = (1 to 100).toArray + + test("basic open, next, fetch, close") { + val node = new DummyLocalNode(data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { i => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 1) + assert(node.fetch().getInt(0) === i) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyLocalNode(data) + val iter = node.asIterator + node.open() + data.foreach { i => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 1) + assert(item.getInt(0) === i) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyLocalNode(data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 1)) + assert(collected.map(_.getInt(0)) === data) + node.close() + } + +} + +/** + * A dummy [[LocalNode]] that just returns one row per integer in the input. 
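The open/next/fetch/close sequence asserted above amounts to a simple pull-based iterator protocol. A minimal, hypothetical driver loop over such a node (not code from this patch; it uses only the methods exercised by the suite) could look like:

// Drive a LocalNode-style operator by hand, mirroring the contract checked in
// "basic open, next, fetch, close".
val node = new DummyLocalNode((1 to 3).toArray)
node.open()
try {
  while (node.next()) {
    val row = node.fetch()   // fetch() is idempotent until the next call to next()
    println(row.getInt(0))   // each row carries a single int field
  }
} finally {
  node.close()
}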
+ */ +private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { + private var index = Int.MinValue + + def this(input: Array[Int]) { + this(new SQLConf, input) + } + + def isOpen: Boolean = { + index != Int.MinValue + } + + override def output: Seq[Attribute] = { + Seq(AttributeReference("something", IntegerType)()) + } + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + val values = Array(input(index).asInstanceOf[Any]) + new GenericInternalRow(values) + } + + override def close(): Unit = { + index = Int.MinValue + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 95f06081bd0a8..b95d4ea7f8f2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} -class LocalNodeTest extends SparkFunSuite { +class LocalNodeTest extends SparkFunSuite with SharedSQLContext { + + def conf: SQLConf = sqlContext.conf /** * Runs the LocalNode and makes sure the answer matches the expected result. @@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite { protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { new SeqScanNode( + conf, df.queryExecution.sparkPlan.output, df.queryExecution.toRdd.map(_.copy()).collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index ffcf092e2c66a..38e0a230c46d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( testData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), testData.select("value", "key").collect() ) } @@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { val columns = Seq(output(1), output(0)) checkAnswer( emptyTestData, - node => ProjectNode(columns, node), + node => ProjectNode(conf, columns, node), emptyTestData.select("value", "key").collect() ) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala similarity index 55% rename from core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index a4568e849fa13..87a7da453999c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -15,27 +15,26 @@ 
* limitations under the License. */ -package org.apache.spark.network.nio +package org.apache.spark.sql.execution.local -import java.nio.ByteBuffer +class SampleNodeSuite extends LocalNodeTest { -import scala.collection.mutable.ArrayBuffer + import testImplicits._ -private[nio] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size: Int = if (buffer == null) 0 else buffer.remaining - - lazy val buffers: ArrayBuffer[ByteBuffer] = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer + private def testSample(withReplacement: Boolean): Unit = { + test(s"withReplacement: $withReplacement") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), + input.sample(withReplacement, 0.3, seed).collect() + ) } - ab } - override def toString: String = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } + testSample(withReplacement = true) + testSample(withReplacement = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 0000000000000..ff28b24eeff14 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { + val testCaseName = if (desc) "desc" else "asc" + test(testCaseName) { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val sortColumn = if (desc) input.col("key").desc else input.col("key") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), + input.sort(sortColumn).limit(5).collect() + ) + } + } + + testTakeOrderedAndProjectNode(desc = false) + testTakeOrderedAndProjectNode(desc = true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index 34670287c3e1d..eedd7320900f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( testData, testData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), testData.unionAll(testData).collect() ) } @@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { checkAnswer2( emptyTestData, emptyTestData, - (node1, node2) => UnionNode(Seq(node1, node2)), + (node1, node2) => UnionNode(conf, Seq(node1, node2)), emptyTestData.unionAll(emptyTestData).collect() ) } @@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { emptyTestData, emptyTestData, testData, emptyTestData) doCheckAnswer( dfs, - nodes => UnionNode(nodes), + nodes => UnionNode(conf, nodes), dfs.reduce(_.unionAll(_)).collect() ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 80006bf077fe8..6afffae161ef6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -36,7 +36,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") + val l = SQLMetrics.createLongMetric(sparkContext, "long") val f = () => { l += 1L l.add(1L) @@ -50,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. 
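Stepping back to the local-node suites above: they all make the same mechanical change, passing the active SQLConf as the first constructor argument of every node. The small helper below is purely illustrative (the wrapper object is made up), but the LimitNode and UnionNode calls match the updated call sites in the suites, e.g. LimitNode(conf, 10, node) where it used to be LimitNode(10, node).

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf

// Illustrative sketch of the new conf-first constructor shape.
object ConfFirstArgSketch {
  def limitTen(conf: SQLConf, child: LocalNode): LocalNode =
    LimitNode(conf, 10, child)

  def unionAll(conf: SQLConf, children: Seq[LocalNode]): LocalNode =
    UnionNode(conf, children)
}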
- val l = ctx.sparkContext.accumulator(0L) + val l = sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -71,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet df.collect() - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => @@ -474,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. 
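The execution-tracking boilerplate repeated in SQLMetricsSuite (capture the known execution IDs, run the query, wait for the listener bus, then diff the ID sets) could be read as the helper sketched below. No such helper exists in the patch; the sketch reuses only the listener calls already visible above and, if it lived inside the suite, would inherit the same package-private access the tests rely on.

// Sketch: find the single execution ID produced by running `query`.
def executionIdOf(query: => Unit): Long = {
  val before = sqlContext.listener.executionIdToData.keySet
  query
  sparkContext.listenerBus.waitUntilEmpty(10000)
  val added = sqlContext.listener.executionIdToData.keySet.diff(before)
  require(added.size == 1, s"expected exactly one new execution, got ${added.size}")
  added.head
}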
assert(metricValues.values.toSeq === Seq(2L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 80d1e88956949..2bbb41ca777b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("basic") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d8c9a08d84c61..ed710689cc670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -255,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("Basic API") { - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -330,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === 
java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -340,8 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -349,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -396,7 +396,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 5dc3a2c07b8c7..e23ee6693133b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,13 +22,12 @@ import java.util.Properties import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null @@ -76,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon conn1.close() } - private lazy val sc = ctx.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) private lazy val schema2 = StructType( @@ -91,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = 
ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -143,14 +141,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === 
sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 9bc3f6bcf6fce..6fc9febe49707 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,10 +26,8 @@ import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils - class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index d74d29fb0beb0..af04079ec895a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ - private[sql] abstract class DataSourceTest extends QueryTest { - protected def _sqlContext: SQLContext // We want to test some edge cases. 
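The JDBCWriteSuite assertions above all follow one write-then-read-back shape. A self-contained sketch of that round trip is shown below; the table name TEST.EXAMPLE is made up, the in-memory H2 URL matches the one the suite uses, and sqlContext/sparkContext are assumed to come from the shared test context.

import java.util.Properties

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(
  StructField("name", StringType) ::
  StructField("id", IntegerType) :: Nil)

val df = sqlContext.createDataFrame(
  sparkContext.parallelize(Seq(Row("dave", 42), Row("mary", 222))), schema)

// Write through the JDBC data source, then read the table back and check its shape.
df.write.jdbc("jdbc:h2:mem:testdb2", "TEST.EXAMPLE", new Properties)
val readBack = sqlContext.read.jdbc("jdbc:h2:mem:testdb2", "TEST.EXAMPLE", new Properties)
assert(readBack.count() == 2)
assert(readBack.collect()(0).length == 2)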
protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(_sqlContext.sparkContext) + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 084d83f6e9bff..5b70d258d6ce3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 79b6e9b45c009..c9791879ec74c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = ctx.range(100).select($"id", lit(1).as("data")) + val df = sqlContext.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = ctx.range(100) + val base = sqlContext.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index f18546b4c2d9b..10d261368993d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null private var df: DataFrame = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 3fc02df954e23..520dea7f7dd92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} * A collection of sample data used in SQL tests. */ private[sql] trait SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext } import internalImplicits._ @@ -37,21 +37,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -63,7 +63,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -71,14 +71,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -90,7 +90,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -102,7 +102,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( BinaryData("12".getBytes, 1) :: BinaryData("22".getBytes, 5) :: BinaryData("122".getBytes, 3) :: @@ -113,7 +113,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -125,7 +125,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -135,7 +135,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = 
_sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) rdd.toDF().registerTempTable("arrayData") @@ -143,7 +143,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -154,13 +154,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -168,7 +168,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -178,7 +178,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -188,7 +188,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -197,13 +197,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -212,13 +212,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -226,7 +226,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -234,7 +234,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = 
_sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -246,7 +246,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(sqlContext != null, "attempted to initialize test data before SQLContext.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index dc08306ad9cb4..9214569f18e93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils @@ -47,13 +47,13 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sparkContext = sqlContext.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = _sqlContext.sql _ + protected lazy val sql = sqlContext.sql _ /** * A helper object for importing SQL implicits. @@ -63,7 +63,14 @@ private[sql] trait SQLTestUtils * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext + + // This must live here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } } /** @@ -84,8 +91,8 @@ private[sql] trait SQLTestUtils /** * The Hadoop configuration used by the active [[SQLContext]]. */ - protected def configuration: Configuration = { - _sqlContext.sparkContext.hadoopConfiguration + protected def hadoopConfiguration: Configuration = { + sparkContext.hadoopConfiguration } /** @@ -96,12 +103,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(_sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) - case (key, None) => _sqlContext.conf.unsetConf(key) + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) } } } @@ -133,7 +140,7 @@ private[sql] trait SQLTestUtils * Drops temporary table `tableName` after calling `f`. 
*/ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(_sqlContext.dropTempTable) + try f finally tableNames.foreach(sqlContext.dropTempTable) } /** @@ -142,7 +149,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - _sqlContext.sql(s"DROP TABLE IF EXISTS $name") + sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -155,12 +162,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - _sqlContext.sql(s"CREATE DATABASE $dbName") + sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -168,8 +175,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - _sqlContext.sql(s"USE $db") - try f finally _sqlContext.sql(s"USE default") + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") } /** @@ -177,7 +184,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(_sqlContext, plan) + DataFrame(sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index d23c6a0732669..963d10eed62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.{ColumnName, SQLContext} +import org.apache.spark.sql.SQLContext /** @@ -36,9 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def ctx: TestSQLContext = _ctx - protected def sqlContext: TestSQLContext = _ctx - protected override def _sqlContext: SQLContext = _ctx + protected def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. @@ -64,15 +62,4 @@ trait SharedSQLContext extends SQLTestUtils { super.afterAll() } } - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - // This must be duplicated here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 92ef2f7d74ba1..10e633f3cde46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -31,13 +31,24 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel new SparkConf().set("spark.sql.testkey", "true"))) } - // Use fewer partitions to speed up testing + // Make sure we set those test specific confs correctly when we create + // the SQLConf as well as when we call clear. 
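The withSQLConf and withTempTable helpers in SQLTestUtils above follow a loan pattern: apply the setting or keep the table for the body, then restore the previous conf value or drop the table in a finally block. A minimal usage sketch, assuming a hypothetical suite that mixes in SharedSQLContext:

package org.apache.spark.sql.test

import org.apache.spark.sql.{QueryTest, Row}

// Hypothetical suite exercising the SQLTestUtils helpers: the shuffle-partitions
// setting is restored and the temp table dropped even if the test body throws.
class ExampleHelpersSuite extends QueryTest with SharedSQLContext {
  test("withSQLConf and withTempTable clean up after themselves") {
    withSQLConf("spark.sql.shuffle.partitions" -> "2") {
      sqlContext.range(4).registerTempTable("tiny")
      withTempTable("tiny") {
        checkAnswer(sql("SELECT count(*) FROM tiny"), Row(4L))
      }
    }
  }
}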
protected[sql] override def createSession(): SQLSession = new this.SQLSession() /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + + clear() + + override def clear(): Unit = { + super.clear() + + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } } } @@ -47,6 +58,17 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } private object testData extends SQLTestData { - protected override def _sqlContext: SQLContext = self + protected override def sqlContext: SQLContext = self } } + +private[sql] object TestSQLContext { + + /** + * A map used to store all confs that need to be overridden in sql/core unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5") +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b8da0840ae569..0a5569b0a4446 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -767,7 +767,7 @@ private[hive] case class InsertIntoHiveTable( private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) - (@transient sqlContext: SQLContext) + (@transient private val sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index dc355690852bd..e35468a624c3e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -54,10 +54,10 @@ private[hive] sealed trait TableReader { */ private[hive] class HadoopTableReader( - @transient attributes: Seq[Attribute], - @transient relation: MetastoreRelation, - @transient sc: HiveContext, - @transient hiveExtraConf: HiveConf) + @transient private val attributes: Seq[Attribute], + @transient private val relation: MetastoreRelation, + @transient private val sc: HiveContext, + hiveExtraConf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". 
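The @transient parameters above are rewritten as @transient private val because the annotation only takes effect on an actual field; on a bare constructor parameter it can be silently discarded. A small sketch of the pattern, with hypothetical class names:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-ins for SQLContext/HiveConf style members: marking the
// parameter as a private val makes it a real field, so @transient applies and
// the handle is skipped during serialization instead of being shipped along.
class HeavyHandle                                   // deliberately not Serializable
class CarriesHandle(@transient private val handle: HeavyHandle) extends Serializable {
  def hasHandle: Boolean = handle != null           // true on the driver; null after a round trip
}

object TransientSketch {
  def main(args: Array[String]): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(new CarriesHandle(new HeavyHandle))  // succeeds: the transient field is not written
    out.close()
  }
}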
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7856037508412..1fe4cba9571f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -131,6 +131,7 @@ private[hive] class IsolatedClientLoader( name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || + (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index c7651daffe36e..32bddbaeaeaf9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -53,7 +53,7 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema)(@transient sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { override def otherCopyArgs: Seq[HiveContext] = sc :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8dc796b056a72..29a6f08f40728 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc) extends Logging with SparkHadoopMapRedUtil @@ -163,7 +163,7 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { } private[spark] class SparkHiveDynamicPartitionWriterContainer( - @transient jobConf: JobConf, + jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { @@ -194,10 +194,10 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then // load it with loadDynamicPartitions/loadPartition/loadTable. 
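The clause added to IsolatedClientLoader above shares ordinary Hadoop classes with Spark's class loader while keeping org.apache.hadoop.hive.* inside the isolated client, so each metastore version still loads its own Hive classes. Restated as a standalone predicate with a hypothetical name:

// Hypothetical restatement of the new Hadoop-sharing rule from IsolatedClientLoader.
object SharedClassRule {
  def isSharedHadoopClass(name: String): Boolean =
    name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")

  def main(args: Array[String]): Unit = {
    assert(isSharedHadoopClass("org.apache.hadoop.fs.Path"))              // shared with Spark
    assert(!isSharedHadoopClass("org.apache.hadoop.hive.conf.HiveConf"))  // stays isolated
  }
}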
- val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + val oldMarker = conf.value.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) super.commitJob() - jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } override def getLocalFileWriter(row: InternalRow, schema: StructType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 4eeca9aec12bd..d1f30e188eafb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -25,9 +25,9 @@ import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} +import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, StructTypeInfo} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -89,21 +89,10 @@ private[orc] class OrcOutputWriter( TypeInfoUtils.getTypeInfoFromTypeString( HiveMetastoreTypes.toMetastoreType(dataSchema)) - TypeInfoUtils - .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) - .asInstanceOf[StructObjectInspector] + OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) + .asInstanceOf[SettableStructObjectInspector] } - // Used to hold temporary `Writable` fields of the next row to be written. - private val reusableOutputBuffer = new Array[Any](dataSchema.length) - - // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.asScala - .zip(dataSchema.fields.map(_.dataType)) - .map { case (ref, dt) => - wrapperFor(ref.getFieldObjectInspector, dt) - }.toArray - // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. 
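The commitJob change above, from hiveWriterContainers.scala, reads and writes the success-marker flag through conf.value rather than the raw jobConf, since the JobConf is now carried in a SerializableJobConf wrapper instead of a @transient field. A rough sketch of that shape, with a hypothetical container class and the marker key written out as an assumption:

package org.apache.spark.sql.hive

import org.apache.hadoop.mapred.JobConf
import org.apache.spark.util.SerializableJobConf

// Hypothetical container distilled from the writer-container change: every access
// to the job configuration, including in commitJob(), goes through conf.value.
private[hive] class ExampleWriterContainer(jobConf: JobConf) extends Serializable {
  protected val conf = new SerializableJobConf(jobConf)

  def commitJob(): Unit = {
    // Assumed to match Hadoop's SUCCESSFUL_JOB_OUTPUT_DIR_MARKER key.
    val markerKey = "mapreduce.fileoutputcommitter.marksuccessfuljobs"
    val oldMarker = conf.value.getBoolean(markerKey, true)
    conf.value.setBoolean(markerKey, false)
    // ... commit the Hadoop job here ...
    conf.value.setBoolean(markerKey, oldMarker)
  }
}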
private var recordWriterInstantiated = false @@ -127,16 +116,32 @@ private[orc] class OrcOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - override protected[sql] def writeInternal(row: InternalRow): Unit = { + private def wrapOrcStruct( + struct: OrcStruct, + oi: SettableStructObjectInspector, + row: InternalRow): Unit = { + val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < row.numFields) { - reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) + while (i < fieldRefs.size) { + oi.setStructFieldData( + struct, + fieldRefs.get(i), + wrap( + row.get(i, dataSchema(i).dataType), + fieldRefs.get(i).getFieldObjectInspector, + dataSchema(i).dataType)) i += 1 } + } + + val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + wrapOrcStruct(cachedOrcStruct, structOI, row) recordWriter.write( NullWritable.get(), - serializer.serialize(reusableOutputBuffer, structOI)) + serializer.serialize(cachedOrcStruct, structOI)) } override def close(): Unit = { @@ -203,7 +208,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - job.getConfiguration match { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -259,7 +264,7 @@ private[orc] case class OrcTableScan( maybeStructOI.map { soi => val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + soi.getStructFieldRef(attr.name) -> ordinal }.unzip val unwrappers = fieldRefs.map(unwrapperFor) // Map each tuple to a row object @@ -284,7 +289,7 @@ private[orc] case class OrcTableScan( def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 57fea5d8db343..be335a47dcabd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand @@ -51,6 +51,11 @@ object TestHive // SPARK-8910 .set("spark.ui.enabled", "false"))) +trait TestHiveSingleton { + protected val sqlContext: SQLContext = TestHive + protected val hiveContext: TestHiveContext = TestHive +} + /** * A locally running test instance of Spark's Hive execution engine. * @@ -111,18 +116,28 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) + // Make sure we set those test specific confs correctly when we create + // the SQLConf as well as when we call clear. 
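The TestHiveSingleton trait introduced in TestHive.scala above is what the converted suites further down mix in instead of importing from the TestHive object directly. A minimal sketch of a suite written against it, with a hypothetical name and table:

package org.apache.spark.sql.hive

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.hive.test.TestHiveSingleton

// Hypothetical suite: sqlContext and hiveContext come from the trait, and
// sql/implicits are imported from hiveContext rather than from TestHive._.
class ExampleHiveSingletonSuite extends QueryTest with TestHiveSingleton {
  import hiveContext.implicits._
  import hiveContext.sql

  test("query a temp table through the shared context") {
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("tiny")
    assert(sql("SELECT count(*) FROM tiny").head().getLong(0) === 2L)
  }
}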
override protected[sql] def createSession(): SQLSession = { new this.SQLSession() } protected[hive] class SQLSession extends super.SQLSession { - /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + + clear() + + override def clear(): Unit = { + super.clear() + + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } } } @@ -450,3 +465,15 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } } + +private[hive] object TestHiveContext { + + /** + * A map used to store all confs that need to be overridden in sql/hive unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5" + ) +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 019d8a30266e2..b4bf9eef8fca5 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -40,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,7 +52,7 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } @@ -71,7 +71,7 @@ public void tearDown() throws IOException { @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -95,7 +95,7 @@ public void testUDAF() { registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); - List expectedResult = new ArrayList(); + List expectedResult = new ArrayList<>(); expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); checkAnswer( aggregatedDF, diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 4192155975c47..c8d272794d10b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -53,7 +53,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { 
Assert.fail(errorMessage); @@ -77,7 +77,7 @@ public void setUp() throws IOException { fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -97,7 +97,7 @@ public void tearDown() throws IOException { @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -120,7 +120,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -132,7 +132,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -148,7 +148,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 39d315aaeab57..9adb3780a2c55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest { +class CachedTableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -95,18 +95,18 @@ class CachedTableSuite extends QueryTest { test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestHive.uncacheTable("src") + hiveContext.uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.sql("CACHE TABLE src") + sql("CACHE TABLE src") assertCached(table("src")) - assert(TestHive.isCached("src"), "Table 'src' should be cached") + assert(hiveContext.isCached("src"), "Table 'src' should be cached") - TestHive.sql("UNCACHE TABLE src") + sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + assert(!hiveContext.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 30f5313d2b812..cf737836939f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,12 +22,12 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -122,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { test(name) { val error = intercept[AnalysisException] { - quietly(sql(query)) + quietly(hiveContext.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index fb10f8583da99..2e5cae415e54b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. 
-class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { +class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext.implicits._ + import hiveContext.sql + private var testData: DataFrame = _ override def beforeAll() { testData = Seq((1, 2), (2, 4)).toDF("a", "b") - TestHive.registerDataFrameAsTable(testData, "mytable") + hiveContext.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - TestHive.dropTempTable("mytable") + hiveContext.dropTempTable("mytable") } test("rollup") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 52e782768cb75..f621367eb553b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest { +class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c177cbdd991cf..2c98f1c3cc49c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HiveDataFrameWindowSuite extends QueryTest { +class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ + import hiveContext.sql test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 574624d501f22..107457f79ec03 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,18 +19,15 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import 
org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{Row, SaveMode, SQLContext} -import org.apache.spark.{Logging, SparkFunSuite} - -class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { + import hiveContext.implicits._ test("struct field should accept underscore in sub-column name") { val hiveTypeStr = "struct" @@ -46,14 +43,15 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } test("duplicated metastore relations") { - val df = sql("SELECT * FROM src") + val df = hiveContext.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } -class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive +class DataSourceWithHiveMetastoreCatalogSuite + extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ import testImplicits._ private val testDF = range(1, 3).select( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index fe0db5228de16..5596ec6882ea2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest { - private val ctx = TestHive - override def _sqlContext: SQLContext = ctx +class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton { test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -53,7 +51,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -66,7 +64,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index dc2d85f48624c..97df249bdb6d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException 
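HiveParquetSuite above now reads back through hiveContext.read; the round trip those tests rely on looks roughly like the following hypothetical suite:

package org.apache.spark.sql.hive

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.hive.test.TestHiveSingleton

// Hypothetical sketch of the write/read Parquet round trip used in HiveParquetSuite.
class ExampleParquetRoundTripSuite extends QueryTest with ParquetTest with TestHiveSingleton {
  test("write a table out as Parquet and read it back") {
    withTempPath { dir =>
      sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
      hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p")
      withTempTable("p") {
        checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM src").collect().toSeq)
      }
    }
  }
}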
import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLContext, QueryTest} import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.DecimalType @@ -173,6 +173,7 @@ object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") @@ -264,6 +265,7 @@ object SparkSQLConfTest extends Logging { // For this simple test, we do not really clone this object. override def clone: SparkConf = this } + conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) // Run a simple command to make sure all lazy vals in hiveContext get instantiated. @@ -272,20 +274,24 @@ object SparkSQLConfTest extends Logging { } } -object SPARK_9757 extends QueryTest with Logging { +object SPARK_9757 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() .set("spark.sql.hive.metastore.version", "0.13.1") - .set("spark.sql.hive.metastore.jars", "maven")) + .set("spark.sql.hive.metastore.jars", "maven") + .set("spark.ui.enabled", "false")) val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext import hiveContext.implicits._ - import org.apache.spark.sql.functions._ - val dir = Utils.createTempDir() dir.delete() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d33e81227db88..80a61f82fd24f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -24,28 +24,25 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -/* Implicits */ -import org.apache.spark.sql.hive.test.TestHive._ - case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - +class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ + import hiveContext.sql - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - TestHive.reset() + hiveContext.reset() // Register the testData, which will be used in every test. 
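The spark.ui.enabled settings added to the spark-submit test objects above keep concurrently launched test applications from racing for a UI port. The shape of such an application, with a hypothetical object name (master and app name are supplied by spark-submit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.test.TestHiveContext

// Hypothetical spark-submit test app mirroring the pattern above.
object ExampleSubmitApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.ui.enabled", "false")  // no UI port needed for tests
    val sc = new SparkContext(conf)
    val hiveContext = new TestHiveContext(sc)
    hiveContext.sql("SHOW TABLES").collect()                     // touch the metastore once
    sc.stop()
  }
}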
testData.registerTempTable("testData") } @@ -96,9 +93,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -169,8 +166,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -185,9 +182,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -202,9 +199,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -217,11 +214,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { } test("SPARK-5498:partition schema does not match table schema") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = TestHive.sparkContext.parallelize( + val testDatawithNull = hiveContext.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index d3388a9429e41..579631df772b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row -class ListTablesSuite extends QueryTest with BeforeAndAfterAll { +class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ - import org.apache.spark.sql.hive.test.TestHive.implicits._ - - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 20a50586d5201..bf0db084906cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,15 +22,11 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -39,10 +35,9 @@ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. 
*/ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll - with Logging { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 997c667ec0d1b..f16c257ab5ab4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} -class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val _sqlContext: HiveContext = TestHive - private val sqlContext = _sqlContext - - private val df = sqlContext.range(10).coalesce(1) +class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + private lazy val df = sqlContext.range(10).coalesce(1) private def checkTablePath(dbName: String, tableName: String): Unit = { - // val hiveContext = sqlContext.asInstanceOf[HiveContext] - val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) - val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) + val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName assert(metastoreTable.serdeProperties("path") === expectedPath) } @@ -220,7 +216,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") - sqlContext.refreshTable("t") + hiveContext.refreshTable("t") checkAnswer( sqlContext.table("t"), df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) @@ -252,7 +248,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") - sqlContext.refreshTable(s"$db.t") + hiveContext.refreshTable(s"$db.t") checkAnswer( sqlContext.table(s"$db.t"), df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 91d7a48208e8d..49aab85cf1aaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -18,38 +18,20 @@ package org.apache.spark.sql.hive import java.sql.Timestamp -import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.conf.HiveConf -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.{Row, SQLConf, SQLContext} - -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with BeforeAndAfterAll 
{ - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHiveSingleton +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. */ private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault - - protected override def beforeAll(): Unit = { - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - Locale.setDefault(Locale.US) - } - - override protected def afterAll(): Unit = { - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - } - override protected def logParquetSchema(path: String): Unit = { val schema = readParquetSchema(path, { path => !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 1cc8a93e83411..f542a5a02508c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,22 +18,18 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ -class QueryPartitionSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ - - protected def _sqlContext = ctx - - test("SPARK-5068: query data when path doesn't exist"){ + test("SPARK-5068: query data when path doesn't exist") { withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e4fec7e2c8a2a..6a692d6fce562 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,24 +17,15 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll - import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - private lazy val ctx: HiveContext = { - val ctx = org.apache.spark.sql.hive.test.TestHive - ctx.reset() - ctx.cacheTables = false - ctx - } - - import ctx.sql +class StatisticsSuite extends QueryTest with TestHiveSingleton { + import hiveContext.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: 
String, c: Class[_]) { @@ -54,9 +45,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { } } - // Ensure session state is initialized. - ctx.parseSql("use default") - assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[HiveNativeCommand]) @@ -80,7 +68,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + hiveContext.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -114,7 +102,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -125,9 +113,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - ctx.analyze("tempTable") + hiveContext.analyze("tempTable") } - ctx.catalog.unregisterTable(Seq("tempTable")) + hiveContext.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -155,8 +143,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -167,8 +155,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) @@ -211,8 +199,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -225,8 +213,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) diff 
--git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 7ee1c8d13aa3f..3ab4576811194 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive +class UDFSuite extends QueryTest with TestHiveSingleton { test("UDF case insensitive") { - ctx.udf.register("random0", () => { Math.random() }) - ctx.udf.register("RANDOM1", () => { Math.random() }) - ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + hiveContext.udf.register("random0", () => { Math.random() }) + hiveContext.udf.register("RANDOM1", () => { Math.random() }) + hiveContext.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4886a85948367..a73b1bd52c09f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.hive.test.TestHiveSingleton -abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext - import sqlContext.implicits._ +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ var originalUseAggregate2: Boolean = _ @@ -69,7 +65,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be data2.write.saveAsTable("agg2") val emptyDF = sqlContext.createDataFrame( - sqlContext.sparkContext.emptyRDD[Row], + sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") @@ -511,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be }.getMessage assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } - - // TODO: once we support Hive UDAF in the new 
interface, - // we can remove the following two tests. - withSQLConf("spark.sql.useAggregate2" -> "true") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | mydoublesum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.SortBasedAggregate => agg - case agg: aggregate.TungstenAggregate => agg - } - val message = - "We should fallback to the old aggregation code path if " + - "there is any aggregate function that cannot be converted to the new interface." - assert(newAggregateOperators.isEmpty, message) - } } } @@ -597,7 +558,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") } - override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => sqlContext.setConf( "spark.sql.TungstenAggregate.testFallbackStartsAt", @@ -605,6 +566,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? val newActual = DataFrame(sqlContext, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { @@ -626,12 +588,12 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue } // Override it to make sure we call the actually overridden checkAnswer. - override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } // Override it to make sure we call the actually overridden checkAnswer. 
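The `checkAnswer` overrides in this suite change their first parameter from `actual: DataFrame` to `actual: => DataFrame` (by-name). This matters because the controlled-fallback suite loops over several `spark.sql.TungstenAggregate.testFallbackStartsAt` settings and needs the query re-planned for each one; a by-value parameter would be evaluated once, before the first `setConf`. A small standalone sketch of the difference (names are mine, not from the patch):

    object ByNameDemo {
      // A by-value argument is evaluated exactly once, at the call site.
      def runByValue(plan: String): Unit = (0 to 2).foreach(_ => println(plan))

      // A by-name argument is re-evaluated on every use, so each iteration
      // can observe configuration changed just before it.
      def runByName(plan: => String): Unit = (0 to 2).foreach(_ => println(plan))

      var conf = 0

      def main(args: Array[String]): Unit = {
        runByValue { conf += 1; s"planned with conf=$conf" }  // prints conf=1 three times
        runByName  { conf += 1; s"planned with conf=$conf" }  // prints conf=2, then 3, then 4
      }
    }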
- override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index b0d3dd44daedc..e38d1eb5779fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -25,8 +25,10 @@ class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") val ts = - new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf)) ts.executeSql("SHOW TABLES").toRdd.collect() ts.executeSql("SELECT * FROM src").toRdd.collect() ts.executeSql("SHOW TABLES").toRdd.collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4d45249d9c6b8..aa95ba94fa873 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { /** * When set, any cache files that result in test failures will be deleted. Used when the test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 11d7a872dff09..94162da4eae1a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{SQLContext, QueryTest} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, @@ -83,7 +80,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils { test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempTable("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - read.json(rdd).registerTempTable("jt") + hiveContext.read.json(rdd).registerTempTable("jt") val outputs = sql( s""" |EXPLAIN EXTENDED diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index efbef68cd4447..0d4c7f86b315a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest { +class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ + test("SPARK-5324 query result of describe command") { - loadTestTable("src") + hiveContext.loadTestTable("src") // register a describe command to be a temp table sql("desc src").registerTempTable("mydesc") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index ba56a8a6b689c..cd055f9eca37e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,11 +21,11 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HivePlanTest extends QueryTest { - import TestHive._ - import TestHive.implicits._ +class HivePlanTest extends QueryTest with TestHiveSingleton { + import hiveContext.sql + import hiveContext.implicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 83f9f3eaa3a5e..fe63ad5683195 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.test.TestHiveContext import 
org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -1104,18 +1105,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... - assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestHiveContext.overrideConfs.size) + val defaults = collectResults(sql("SET")) assertResult(Set(testKey -> testVal)) { collectResults(sql(s"SET $testKey=$testVal")) } - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assert(hiveconf.get(testKey, "") === testVal) + assertResult(defaults ++ Set(testKey -> testVal))(collectResults(sql("SET"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { + assertResult(defaults ++ Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 9c10ffe1113dc..d9ba895e1eceb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils @@ -43,10 +43,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. 
*/ -class HiveUDFSuite extends QueryTest { +class HiveUDFSuite extends QueryTest with TestHiveSingleton { - import TestHive.{udf, sql} - import TestHive.implicits._ + import hiveContext.{udf, sql} + import hiveContext.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -123,12 +123,12 @@ class HiveUDFSuite extends QueryTest { | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } - val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) - TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) + val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { @@ -137,7 +137,7 @@ class HiveUDFSuite extends QueryTest { sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() + hiveContext.reset() } test("SPARK-2693 udaf aggregates test") { @@ -157,7 +157,7 @@ class HiveUDFSuite extends QueryTest { } test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") @@ -168,11 +168,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - TestHive.reset() + hiveContext.reset() } test("UDFToListString") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") @@ -183,11 +183,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - TestHive.reset() + hiveContext.reset() } test("UDFToListInt") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") @@ -198,11 +198,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - TestHive.reset() + hiveContext.reset() } test("UDFToStringIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + @@ -214,11 +214,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFToIntIntMap") { - val testData = 
TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + @@ -230,11 +230,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() @@ -246,11 +246,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - TestHive.reset() + hiveContext.reset() } test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") @@ -261,11 +261,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - TestHive.reset() + hiveContext.reset() } test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") @@ -280,11 +280,11 @@ class HiveUDFSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - TestHive.reset() + hiveContext.reset() } test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: @@ -297,7 +297,7 @@ class HiveUDFSuite extends QueryTest { Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - TestHive.reset() + hiveContext.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ff1d9a2934cc..8126d02335217 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,9 +26,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils @@ -65,12 +63,12 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). 
Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ test("UDTF") { - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( @@ -509,19 +507,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), - Seq.empty[Row]) + + sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") checkAnswer( sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") checkAnswer( sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) @@ -614,7 +612,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val rowRdd = sparkContext.parallelize(row :: Nil) - TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") + hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -1044,10 +1042,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. 
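Several suites in this patch (UDFSuite, HiveExplainSuite, SQLQuerySuite above, and more below) drop their private `TestHive`/`_sqlContext` plumbing and mix in `TestHiveSingleton` instead. The trait itself is defined elsewhere in the patch; a plausible minimal reconstruction, assuming it only exposes the shared `TestHive` context under stable names and resets it when a suite finishes, would be:

    import org.scalatest.BeforeAndAfterAll

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.hive.test.TestHive

    // Hypothetical sketch; the real trait lives in org.apache.spark.sql.hive.test
    // and may differ in detail.
    trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll {
      protected val sqlContext: SQLContext = TestHive
      protected val hiveContext = TestHive
      protected def sparkContext = hiveContext.sparkContext

      protected override def afterAll(): Unit = {
        try {
          hiveContext.reset()  // drop temp tables and cached state so suites stay isolated
        } finally {
          super.afterAll()
        }
      }
    }

With such a trait in scope, `import hiveContext._` and `import hiveContext.implicits._` replace the old `import TestHive._` pattern, which is exactly what the hunks above do.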
- TestHive.sql( - s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + sql( + s"ADD JAR ${hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - TestHive.sql( + sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -1097,21 +1095,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = - TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( - TestHive.sql( + sql( """ |select id, concat(year(datef)) |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) - TestHive.dropTempTable("test_SPARK8588") + dropTempTable("test_SPARK8588") } test("SPARK-9371: fix the support for special chars in column names for hive context") { - TestHive.read.json(TestHive.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1142,8 +1140,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("specifying database name for a temporary table is not allowed") { withTempPath { dir => val path = dir.getCanonicalPath - val df = - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df .write .format("parquet") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 9aca40f15ac15..cb8d0fca8e693 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -22,17 +22,14 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest { - - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { + import hiveContext.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -59,7 +56,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -73,7 +70,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -88,7 +85,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output 
= Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) @@ -105,7 +102,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index deec0048d24b8..92043d66c914f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,17 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName - import sqlContext._ - import sqlContext.implicits._ + // ORC does not play well with NullType and UDT. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _: UserDefinedType[_] => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -48,7 +55,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.options(Map( + hiveContext.read.options(Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index a46ca9a2c9706..52e09f9496f05 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils + // The data where the partitioning key exists only in the directory structure. 
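OrcHadoopFsRelationSuite above (and the JSON, Parquet, and SimpleText suites later in this diff) override a new `supportsDataType` hook so that the shared "test all data types" test in `HadoopFsRelationTest` can skip types a given source cannot round-trip. An illustrative sketch of that shape, with the base class simplified and only the ORC override taken from the patch:

    import org.apache.spark.sql.types._

    // Illustrative only: a base test exposes an overridable capability hook,
    // and concrete suites narrow it per data source.
    abstract class FormatCapabilities {
      // Default: claim support for every type.
      protected def supportsDataType(dataType: DataType): Boolean = true

      // Keep only the fields of a schema this format can round-trip.
      protected def supportedFields(schema: StructType): Seq[StructField] =
        schema.fields.filter(f => supportsDataType(f.dataType)).toSeq
    }

    class OrcCapabilities extends FormatCapabilities {
      // Mirrors the ORC override above: ORC cannot store NullType,
      // CalendarIntervalType, or user-defined types.
      override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
        case _: NullType => false
        case _: CalendarIntervalType => false
        case _: UserDefinedType[_] => false
        case _ => true
      }
    }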
case class OrcParData(intField: Int, stringField: String) @@ -38,7 +36,10 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { @@ -58,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally TestHive.dropTempTable(tableName) + try f finally hiveContext.dropTempTable(tableName) } protected def makePartitionDir( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 744d462938141..8bc33fcf5d906 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -287,6 +287,20 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("SPARK-9170: Don't implicitly lowercase of user-provided columns") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) + sqlContext.read.format("orc").load(path).schema("Acol") + intercept[IllegalArgumentException] { + sqlContext.read.format("orc").load(path).schema("acol") + } + checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + (0 until 10).map(Row(_))) + } + } + test("SPARK-8501: Avoids discovery schema from empty ORC files") { withTempPath { dir => val path = dir.getCanonicalPath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 80c38084f293d..7a34cf731b4c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,12 +21,14 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { +abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + var orcTableDir: File = null var orcTableAsDir: File = null @@ -156,7 +158,7 @@ class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( @@ -164,7 +166,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index f7ba20ff41d8d..88a0ed511749f 100644 
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,15 +22,12 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton -private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive - protected val sqlContext = _sqlContext - import sqlContext.implicits._ - import sqlContext.sparkContext +private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { + import testImplicits._ /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 34d3434569f58..d09dd6ad58191 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,15 +19,11 @@ package org.apache.spark.sql.hive import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -58,6 +54,8 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -536,6 +534,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { + import testImplicits._ + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -619,7 +620,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { val conf = Seq( HiveContext.CONVERT_METASTORE_PARQUET.key -> "false", SQLConf.PARQUET_BINARY_AS_STRING.key -> "true", - SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key -> "true") + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") withSQLConf(conf: _*) { sql( @@ -684,9 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. 
*/ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index b4640b1616281..dc0531a6d4bc5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,16 +18,13 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 8ca3a17085194..ef37787137d07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,7 +28,13 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - import sqlContext._ + // JSON does not write data of NullType and does not play well with BinaryType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: BinaryType => false + case _: CalendarIntervalType => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -47,7 +53,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -65,14 +71,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { val data = Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil - val df = createDataFrame(sparkContext.parallelize(data), schema) + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. 
checkAnswer( - read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } @@ -90,14 +96,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { Row(new BigDecimal("10.02")) :: Row(new BigDecimal("20000.99")) :: Row(new BigDecimal("10000")) :: Nil - val df = createDataFrame(sparkContext.parallelize(data), schema) + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. checkAnswer( - read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 06dadbb5feab0..e2d754e806403 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -24,14 +24,20 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.{execution, AnalysisException, SaveMode} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + override val dataSourceName: String = "parquet" - import sqlContext._ - import sqlContext.implicits._ + // Parquet does not play well with NullType. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: NullType => false + case _: CalendarIntervalType => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -51,7 +57,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -69,7 +75,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -97,7 +103,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) + checkAnswer(hiveContext.read.format("parquet").load(path), df) } } @@ -107,7 +113,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) + hiveContext.range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") @@ -125,7 +131,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val summaryPath = new Path(path, "_metadata") val commonSummaryPath = new Path(path, "_common_metadata") - val fs = summaryPath.getFileSystem(configuration) + val fs = summaryPath.getFileSystem(hadoopConfiguration) fs.delete(summaryPath, true) fs.delete(commonSummaryPath, true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index e8975e5f5cd08..a3a124488d983 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -20,12 +20,27 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - import sqlContext._ + // We have a very limited number of supported types at here since it is just for a + // test relation and we do very basic testing at here. + override protected def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: BinaryType => false + // We are using random data generator and the generated strings are not really valid string. + case _: StringType => false + case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442 + case _: CalendarIntervalType => false + case _: DateType => false + case _: TimestampType => false + case _: ArrayType => false + case _: MapType => false + case _: StructType => false + case _: UserDefinedType[_] => false + case _ => true + } test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -44,7 +59,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 527ca7a81cad8..aeaaa3e1c5220 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -68,7 +68,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { - val serialized = row.toSeq.map(_.toString).mkString(",") + val serialized = row.toSeq.map { v => + if (v == null) "" else v.toString + }.mkString(",") recordWriter.write(null, new Text(serialized)) } @@ -112,7 +114,8 @@ class SimpleTextRelation( val fields = dataSchema.map(_.dataType) sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => - Row(record.split(",").zip(fields).map { case (value, dataType) => + Row(record.split(",", 
-1).zip(fields).map { case (v, dataType) => + val value = if (v == "") null else v // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) val catalystValue = Cast(Literal(value), dataType).eval() // Here we're converting Catalyst values to Scala values to test `needsConversion` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 7966b43596e75..8ffcef85668d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,18 +28,18 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { import sqlContext.implicits._ val dataSourceName: String + protected def supportsDataType(dataType: DataType): Boolean = true + val dataSchema = StructType( Seq( @@ -100,6 +100,83 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + ignore("test all data types") { + withTempPath { file => + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + // TODO: add CalendarIntervalType to here once we can save it out. + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. + val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schema, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schema")) + val data = (1 to 10).map { i => + dataGenerator.apply() match { + case row: Row => row + case null => Row.fromSeq(Seq.fill(schema.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 10) + val df = sqlContext.createDataFrame(rdd, schema) + + // All columns that have supported data types of this source. + val supportedColumns = schema.fields.collect { + case StructField(name, dataType, _, _) if supportsDataType(dataType) => name + } + val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) + + val dfToBeSaved = df.selectExpr(selectedColumns: _*) + + // Save the data out. 
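The SimpleTextRelation change just above writes null values as empty strings and parses records with `split(",", -1)` instead of `split(",")`. The explicit limit is what makes the null round-trip work, because `String.split` with the default limit silently drops trailing empty fields. A quick standalone illustration (object name is mine):

    object SplitLimitDemo {
      def main(args: Array[String]): Unit = {
        val record = "1,,3,"  // second and fourth columns were null, written as ""

        println(record.split(",").mkString("[", "|", "]"))      // [1||3]   - trailing empty field lost
        println(record.split(",", -1).mkString("[", "|", "]"))  // [1||3|]  - all four fields kept

        // Mapping "" back to null restores what the writer emitted.
        val row = record.split(",", -1).map(v => if (v == "") null else v)
        println(row.mkString("[", "|", "]"))                    // [1|null|3|null]
      }
    }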
+ dfToBeSaved + .write + .format(dataSourceName) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .save(file.getCanonicalPath) + + val loadedDF = + sqlContext + .read + .format(dataSourceName) + .schema(dfToBeSaved.schema) + .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. + .load(file.getCanonicalPath) + .selectExpr(selectedColumns: _*) + + // Read the data back. + checkAnswer( + loadedDF, + dfToBeSaved + ) + } + } + test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) @@ -298,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) @@ -504,17 +594,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } test("SPARK-8578 specified custom output committer will not be used to append data") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) try { val df = sqlContext.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there data already exists, // this append should succeed because we will use the output committer associated @@ -533,12 +623,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } withTempPath { dir => - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. 
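The custom-output-committer tests below mutate the shared Hadoop `Configuration` (now reached through `hadoopConfiguration` rather than `configuration`), so they clone it up front and restore it in `finally`; since Hadoop 1 has no `Configuration.unset`, restoring means clearing the conf and replaying the snapshot. A compact sketch of that pattern outside the Spark test scaffolding (helper name is mine):

    import scala.collection.JavaConverters._

    import org.apache.hadoop.conf.Configuration

    object WithHadoopConf {
      /** Run `body` with `key` temporarily set on `conf`, then restore the original entries. */
      def withTempSetting(conf: Configuration, key: String, value: String)(body: => Unit): Unit = {
        val snapshot = new Configuration(conf)  // copy-constructor clones all entries
        try {
          conf.set(key, value)
          body
        } finally {
          // Hadoop 1 has no Configuration.unset, so wipe and replay the snapshot instead.
          conf.clear()
          snapshot.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
        }
      }
    }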
- configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there is no existing data, // this append will fail because AlwaysFailOutputCommitter is used when we do append @@ -549,8 +639,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } @@ -570,7 +660,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } test("SPARK-9899 Disable customized output committer when speculation is on") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) @@ -580,7 +670,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { sqlContext.sparkContext.conf.set("spark.speculation", "true") // Uses a customized output committer which always fails - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) @@ -597,8 +687,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index cd5d960369c05..8a6050f5227bf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) +class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName @@ -49,6 +49,8 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) // Reload properties for the checkpoint application since user wants to set a reload property // or spark had changed its value and user wants to set it back. 
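The Checkpoint change below adds `spark.yarn.app.id` and `spark.yarn.app.attemptId` to `propertiesToReload`, presumably so a driver restarted from a checkpoint on YARN picks up the identifiers of the new application rather than the stale ones frozen in the checkpointed `SparkConf`. The reload idea in isolation, with names of my own choosing (the real logic lives elsewhere in Checkpoint.scala and is not shown in this hunk):

    import org.apache.spark.SparkConf

    object ReloadDemo {
      // Keys whose values must come from the *current* run, not from the checkpoint.
      val propertiesToReload = Seq(
        "spark.yarn.app.id",
        "spark.yarn.app.attemptId",
        "spark.driver.host",
        "spark.driver.port")

      /** Overlay reloadable keys from the live conf onto the conf restored from a checkpoint. */
      def applyReload(checkpointed: SparkConf, current: SparkConf): SparkConf = {
        propertiesToReload.foreach { key =>
          current.getOption(key).foreach(value => checkpointed.set(key, value))
        }
        checkpointed
      }
    }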
val propertiesToReload = List( + "spark.yarn.app.id", + "spark.yarn.app.attemptId", "spark.driver.host", "spark.driver.port", "spark.master", @@ -319,7 +321,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) - val compressionCodec = CompressionCodec.createCodec(conf) + var readError: Exception = null checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { @@ -330,13 +332,15 @@ object CheckpointReader extends Logging { return Some(cp) } catch { case e: Exception => + readError = e logWarning("Error reading checkpoint from file " + file, e) } }) // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + throw new SparkException( + s"Failed to read checkpoint from directory $checkpointPath", readError) } None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 2c373640d2fd9..dfc569451df86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -170,7 +170,7 @@ private[python] object PythonDStream { */ private[python] abstract class PythonDStream( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -187,7 +187,7 @@ private[python] abstract class PythonDStream( */ private[python] class PythonTransformedDStream ( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -206,7 +206,7 @@ private[python] class PythonTransformedDStream ( private[python] class PythonTransformed2DStream( parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -230,7 +230,7 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - @transient reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -252,8 +252,8 @@ private[python] class PythonStateDStream( */ private[python] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, + preduceFunc: PythonTransformFunction, + @transient private val pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration) extends PythonDStream(parent, preduceFunc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index c358f5b5bd70b..40208a64861fb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -70,7 +70,7 @@ import 
org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index a6c4cd220e42f..95994c983c0cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param ssc_ Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) +abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) extends DStream[T](ssc_) { private[streaming] var lastValidTime: Time = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 186e1bf03a944..002aac9f43617 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { def getReceiver(): Receiver[T] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a2f5d82a79bd3..a2685046e03d4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.io.{NotSerializableException, ObjectOutputStream} +import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag @@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Time, StreamingContext} private[streaming] class QueueInputDStream[T: ClassTag]( - @transient ssc: StreamingContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] @@ -37,8 +37,13 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def readObject(in: ObjectInputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing. 
" + + "Please don't use queueStream when checkpointing is enabled.") + } + private def writeObject(oos: ObjectOutputStream): Unit = { - throw new NotSerializableException("queueStream doesn't support checkpointing") + logWarning("queueStream doesn't support checkpointing") } override def compute(validTime: Time): Option[RDD[T]] = { @@ -52,7 +57,7 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime) { Some(buffer.head) } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + Some(new UnionRDD(context.sc, buffer.toSeq)) } } else if (defaultRDD != null) { Some(defaultRDD) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index e2925b9e03ec3..5a9eda7c12776 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 6c139f32da31d..87c20afd5c13c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.{StreamingContext, Time} * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) +abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) extends InputDStream[T](ssc_) { /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 5ce5b7aae6e69..de84e0c9a498d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class SocketInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index e081ffe46f502..f811784b25c82 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -61,7 +61,7 @@ class WriteAheadLogBackedBlockRDDPartition( * * * @param sc SparkContext - * @param blockIds Ids of the blocks that contains this RDD's data + * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark * 
executors). If not, then block lookups by the block ids will be skipped. @@ -73,23 +73,23 @@ class WriteAheadLogBackedBlockRDDPartition( */ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient blockIds: Array[BlockId], + sc: SparkContext, + @transient private val _blockIds: Array[BlockId], @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) - extends BlockRDD[T](sc, blockIds) { + extends BlockRDD[T](sc, _blockIds) { require( - blockIds.length == walRecordHandles.length, - s"Number of block Ids (${blockIds.length}) must be " + + _blockIds.length == walRecordHandles.length, + s"Number of block Ids (${_blockIds.length}) must be " + s" same as number of WAL record handles (${walRecordHandles.length})") require( - isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + isBlockIdValid.isEmpty || isBlockIdValid.length == _blockIds.length, s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + - s" same as number of block Ids (${blockIds.length})") + s" same as number of block Ids (${_blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -99,9 +99,9 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) + new WriteAheadLogBackedBlockRDDPartition(i, _blockIds(i), isValid, walRecordHandles(i)) } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13f..c5217149224e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. 
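The streaming hunks above apply two related serialization patterns: @transient is dropped from constructor parameters that are never kept as fields (ssc_, pfunc, sc), it is kept only on fields that genuinely must not be serialized (_blockIds, isBlockIdValid), and QueueInputDStream now rejects deserialization via a readObject hook. A Spark-free sketch of both patterns; the class and field names below are illustrative only:

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectInputStream, ObjectOutputStream}

class Builder  // stands in for a non-serializable collaborator such as StreamingContext

class Node(builder: Builder,                            // plain param: used only at construction, so no field is generated
           @transient private val scratch: Array[Int])  // stored, but excluded from serialization
  extends Serializable {

  // Only data derived from the constructor arguments is kept as serializable state.
  val size: Int = scratch.length

  // Rejecting deserialization explicitly gives a clearer error than a generic failure later.
  private def readObject(in: ObjectInputStream): Unit = {
    throw new NotSerializableException("Node does not support checkpointing")
  }
}

object SerializationSketch {
  def main(args: Array[String]): Unit = {
    val node = new Node(new Builder, Array(1, 2, 3))
    val bytes = new ByteArrayOutputStream()
    // Succeeds: builder is never stored as a field and scratch is transient.
    new ObjectOutputStream(bytes).writeObject(node)
    println(node.size)
  }
}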
public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + 
out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +376,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) 
throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> 
res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +686,13 @@ public void testStreamingContextTransform(){ ); 
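Several of the join/transform tests above compare per-batch output as sets (Sets.newHashSet over each batch) so that ordering inside a batch cannot cause spurious failures. The same idea in a few lines of plain Scala, with illustrative names and no Spark or JUnit dependency:

object OrderInvariantSketch {
  def assertOrderInvariantEquals[T](expected: Seq[Seq[T]], actual: Seq[Seq[T]]): Unit = {
    val expectedSets = expected.map(_.toSet)
    val actualSets = actual.map(_.toSet)
    assert(expectedSets == actualSets, s"$actualSets did not equal $expectedSets")
  }

  def main(args: Array[String]): Unit = {
    // Batch order still matters, but ordering inside each batch does not.
    assertOrderInvariantEquals(Seq(Seq(1, 2), Seq(3)), Seq(Seq(2, 1), Seq(3)))
  }
}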
List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -782,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new 
ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = 
JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", 
Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1216,16 @@ public void testGroupByKeyAndWindow() { } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new 
Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1298,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - 
Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ -1444,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new 
Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1558,29 @@ public Iterable 
call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new 
Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. 
@Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69dec..ec2bffd6a5b97 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; 
} }); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index ec2852d9a0206..047e38ef90998 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -76,6 +76,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { fail("Timeout: cannot finish all batches in 30 seconds") } + // Ensure progress listener has been notified of all events + ssc.scheduler.listenerBus.waitUntilEmpty(500) + // Verify all "InputInfo"s have been reported assert(ssc.progressListener.numTotalReceivedRecords === input.size) assert(ssc.progressListener.numTotalProcessedRecords === input.size) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c20..13cfe29d7b304 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ 
-import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -47,7 +47,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + val conf = new SparkConf() + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + .set("spark.app.id", "streaming-test") val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) @@ -184,7 +186,7 @@ class ReceivedBlockHandlerSuite } test("Test Block - isFullyConsumed") { - val sparkConf = new SparkConf() + val sparkConf = new SparkConf().set("spark.app.id", "streaming-test") sparkConf.set("spark.storage.unrollMemoryThreshold", "512") // spark.storage.unrollFraction set to 0.4 for BlockManager sparkConf.set("spark.storage.unrollFraction", "0.4") @@ -251,7 +253,7 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7423ef6bcb6ea..d26894e88fc26 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel @@ -726,16 +726,26 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } test("queueStream doesn't support checkpointing") { - val checkpointDir = Utils.createTempDir() - ssc = new StreamingContext(master, appName, batchDuration) - val rdd = ssc.sparkContext.parallelize(1 to 10) - ssc.queueStream[Int](Queue(rdd)).print() - ssc.checkpoint(checkpointDir.getAbsolutePath) - val e = intercept[NotSerializableException] { - ssc.start() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + def creatingFunction(): StreamingContext = { + val _ssc = new StreamingContext(conf, batchDuration) + val rdd = _ssc.sparkContext.parallelize(1 to 10) + _ssc.checkpoint(checkpointDirectory) + _ssc.queueStream[Int](Queue(rdd)).register() + _ssc + } + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + val e = intercept[SparkException] { + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) } // StreamingContext.validate changes the message, so use "contains" here - assert(e.getMessage.contains("queueStream 
doesn't support checkpointing")) + assert(e.getCause.getMessage.contains("queueStream doesn't support checkpointing. " + + "Please don't use queueStream when checkpointing is enabled.")) } def addInputStream(s: StreamingContext): DStream[Int] = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 991b5cec00bd8..93621b44c9183 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -590,6 +590,13 @@ private[spark] class ApplicationMaster( case None => logWarning("Container allocator is not ready to kill executors yet.") } context.reply(true) + + case GetExecutorLossReason(eid) => + Option(allocator) match { + case Some(a) => a.enqueueGetLossReasonRequest(eid, context) + case None => logWarning(s"Container allocator is not ready to find" + + s" executor loss reasons yet.") + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e9a02baafd28e..a2c4bc2f5480b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1045,7 +1045,10 @@ object Client extends Logging { s"in favor of the $CONF_SPARK_JAR configuration variable.") System.getenv(ENV_SPARK_JAR) } else { - SparkContext.jarOfClass(this.getClass).head + SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " + + "find jar containing Spark classes. The jar can be defined using the " + + "spark.yarn.jar configuration option. 
If testing Spark, either set that option or " + + "make sure SPARK_PREPEND_CLASSES is not set.")) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 4f42ffefa77f9..54f62e6b723ac 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -96,6 +96,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) } numExecutors = initialNumExecutors + } else { + val numExecutorsConf = "spark.executor.instances" + numExecutors = sparkConf.getInt(numExecutorsConf, numExecutors) } principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 5f897cbcb4e9f..fd88b8b7fe3b9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,8 +21,9 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -36,8 +37,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.util.Utils /** @@ -93,6 +95,11 @@ private[yarn] class YarnAllocator( sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) } + // Executor loss reason requests that are pending - maps from executor ID for inquiry to a + // list of requesters that should be responded to once we find out why the given executor + // was lost. + private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] @@ -235,9 +242,7 @@ private[yarn] class YarnAllocator( val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers.asScala) - logDebug("Finished processing %d completed containers. Current running executor count: %d." .format(completedContainers.size, numExecutorsRunning)) } @@ -429,7 +434,7 @@ private[yarn] class YarnAllocator( for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId val alreadyReleased = releasedContainers.remove(containerId) - if (!alreadyReleased) { + val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. 
The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 @@ -440,22 +445,42 @@ private[yarn] class YarnAllocator( // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { - logInfo("Container preempted: " + containerId) - } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed += 1 + val exitStatus = completedContainer.getExitStatus + val (isNormalExit, containerExitReason) = exitStatus match { + case ContainerExitStatus.SUCCESS => + (true, s"Executor for container $containerId exited normally.") + case ContainerExitStatus.PREEMPTED => + // Preemption should count as a normal exit, since YARN preempts containers merely + // to do resource sharing, and tasks that fail due to preempted executors could + // just as easily finish on any other executor. See SPARK-8167. + (true, s"Container $containerId was preempted.") + // Should probably still count memory exceeded exit codes towards task failures + case VMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + case PMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + case unknown => + numExecutorsFailed += 1 + (false, "Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". 
Diagnostics: " + completedContainer.getDiagnostics) + + } + if (isNormalExit) { + logInfo(containerExitReason) + } else { + logWarning(containerExitReason) } + ExecutorExited(0, isNormalExit, containerExitReason) + } else { + // If we have already released this container, then it must mean + // that the driver has explicitly requested it to be killed + ExecutorExited(completedContainer.getExitStatus, isNormalExit = true, + s"Container $containerId exited from explicit termination request.") } if (allocatedContainerToHostMap.contains(containerId)) { @@ -474,18 +499,35 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - + pendingLossReasonRequests.remove(eid).foreach { pendingRequests => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) // Notify backend about the failure of the executor numUnexpectedContainerRelease += 1 - driverRef.send(RemoveExecutor(eid, - s"Yarn deallocated the executor $eid (container $containerId)")) + driverRef.send(RemoveExecutor(eid, exitReason)) } } } } + /** + * Register that some RpcCallContext has asked the AM why the executor was lost. Note that + * we can only find the loss reason to send back in the next call to allocateResources(). + */ + private[yarn] def enqueueGetLossReasonRequest( + eid: String, + context: RpcCallContext): Unit = synchronized { + if (executorIdToContainer.contains(eid)) { + pendingLossReasonRequests + .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else { + logWarning(s"Tried to get the loss reason for non-existent executor $eid") + } + } + private def internalReleaseContainer(container: Container): Unit = { releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) @@ -501,6 +543,8 @@ private object YarnAllocator { Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + val VMEM_EXCEEDED_EXIT_CODE = -103 + val PMEM_EXCEEDED_EXIT_CODE = -104 def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics)
    <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
    <tr>
      <td><code>spark.streaming.backpressure.enabled</code></td>
      <td>false</td>
      <td>
+    Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5).
+    This enables Spark Streaming to control the receiving rate based on the current
+    batch scheduling delays and processing times, so that the system receives data
+    only as fast as it can process it. Internally, this dynamically sets the maximum
+    receiving rate of receivers. This rate is upper bounded by the values
+    `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition`
+    if they are set (see below).
+      </td>
    </tr>
    <tr>
      <td><code>spark.streaming.blockInterval</code></td>
      <td>200ms</td>
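
As a usage note, the backpressure switch documented above is an ordinary SparkConf entry. The sketch below shows one way to enable it when building a StreamingContext; the application name, the local master, and the 10000 records/sec ceiling are arbitrary example values, not recommendations.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object BackpressureExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")                                // local master only for a self-contained run
          .setAppName("backpressure-example")                   // arbitrary example name
          .set("spark.streaming.backpressure.enabled", "true")  // let the rate controller throttle receivers
          // Optional hard ceiling; the dynamically chosen rate never exceeds this value.
          .set("spark.streaming.receiver.maxRate", "10000")

        val ssc = new StreamingContext(conf, Seconds(1))
        // Define input DStreams and output operations here, then:
        // ssc.start(); ssc.awaitTermination()
      }
    }
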