From 208b902257bbfb85bf8cadfc942b7134ad690f8b Mon Sep 17 00:00:00 2001 From: "Santiago M. Mola" Date: Tue, 12 May 2015 23:44:21 -0700 Subject: [PATCH 001/109] [SPARK-7566][SQL] Add type to HiveContext.analyzer This makes HiveContext.analyzer overrideable. Author: Santiago M. Mola Closes #6086 from smola/patch-3 and squashes the following commits: 8ece136 [Santiago M. Mola] [SPARK-7566][SQL] Add type to HiveContext.analyzer --- .../src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 766c42d040f80..9d98c36e947a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -335,7 +335,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* An analyzer that uses the Hive metastore. */ @transient - override protected[sql] lazy val analyzer = + override protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = catalog.ParquetConversions :: From df9b94a57cbd0e028228059d215b446d59d25ba8 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 12 May 2015 23:52:30 -0700 Subject: [PATCH 002/109] [SPARK-7482] [SPARKR] Rename some DataFrame API methods in SparkR to match their counterparts in Scala. Author: Sun Rui Closes #6007 from sun-rui/SPARK-7482 and squashes the following commits: 5c5cf5e [Sun Rui] Implement alias loadDF() as a new function. 3a30c10 [Sun Rui] Rename load()/save() to read.df()/write.df(). Also add loadDF()/saveDF() as aliases. 9f569d6 [Sun Rui] [SPARK-7482][SparkR] Rename some DataFrame API methods in SparkR to match their counterparts in Scala. --- R/pkg/NAMESPACE | 6 +++-- R/pkg/R/DataFrame.R | 35 +++++++++++++++++----------- R/pkg/R/RDD.R | 4 ++-- R/pkg/R/SQLContext.R | 13 ++++++++--- R/pkg/R/generics.R | 22 +++++++++++------- R/pkg/inst/tests/test_sparkSQL.R | 40 ++++++++++++++++---------------- 6 files changed, 71 insertions(+), 49 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 819e9a24e5c0e..ba29614e7b179 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -37,7 +37,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", - "sampleDF", + "sample", "sample_frac", "saveAsParquetFile", "saveAsTable", @@ -53,7 +53,8 @@ exportMethods("arrange", "unpersist", "where", "withColumn", - "withColumnRenamed") + "withColumnRenamed", + "write.df") exportClasses("Column") @@ -101,6 +102,7 @@ export("cacheTable", "jsonFile", "loadDF", "parquetFile", + "read.df", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 2705817531019..a7fa32e291fb1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -294,8 +294,8 @@ setMethod("registerTempTable", #'\dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") -#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' df <- read.df(sqlCtx, path, "parquet") +#' df2 <- read.df(sqlCtx, path2, "parquet") #' registerTempTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} @@ -473,14 +473,14 @@ setMethod("distinct", dataFrame(sdf) }) -#' SampleDF +#' Sample #' #' Return a sampled subset of this DataFrame using a random seed. 
#' #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @rdname sampleDF +#' @rdname sample #' @aliases sample_frac #' @export #' @examples @@ -489,10 +489,10 @@ setMethod("distinct", #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlCtx, path) -#' collect(sampleDF(df, FALSE, 0.5)) -#' collect(sampleDF(df, TRUE, 0.5)) +#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, TRUE, 0.5)) #'} -setMethod("sampleDF", +setMethod("sample", # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", @@ -503,13 +503,13 @@ setMethod("sampleDF", dataFrame(sdf) }) -#' @rdname sampleDF -#' @aliases sampleDF +#' @rdname sample +#' @aliases sample setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), function(x, withReplacement, fraction) { - sampleDF(x, withReplacement, fraction) + sample(x, withReplacement, fraction) }) #' Count @@ -1303,7 +1303,7 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname saveAsTable +#' @rdname write.df #' @export #' @examples #'\dontrun{ @@ -1311,9 +1311,9 @@ setMethod("except", #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlCtx, path) -#' saveAsTable(df, "myfile") +#' write.df(df, "myfile", "parquet", "overwrite") #' } -setMethod("saveDF", +setMethod("write.df", signature(df = "DataFrame", path = 'character', source = 'character', mode = 'character'), function(df, path = NULL, source = NULL, mode = "append", ...){ @@ -1334,6 +1334,15 @@ setMethod("saveDF", callJMethod(df@sdf, "save", source, jmode, options) }) +#' @rdname write.df +#' @aliases saveDF +#' @export +setMethod("saveDF", + signature(df = "DataFrame", path = 'character', source = 'character', + mode = 'character'), + function(df, path = NULL, source = NULL, mode = "append", ...){ + write.df(df, path, source, mode, ...) + }) #' saveAsTable #' diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 9138629cac9c0..d3a68fff780ce 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -927,7 +927,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", MAXINT))))) # TODO(zongheng): investigate if this call is an in-place shuffle? - sample(samples)[1:total] + base::sample(samples)[1:total] }) # Creates tuples of the elements in this RDD by applying a function. 
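# Illustrative note, not part of the patch: the base::sample() qualification in
# this file is needed because the same change set defines a SparkR S4 generic
# named "sample" for DataFrames (renamed from sampleDF), which takes over the
# name inside the package namespace. A rough sketch of the difference, using a
# made-up vector xs:
#
#   xs <- 1:10
#   base::sample(xs, 3)      # always base R's random sampler
#   sample(df, FALSE, 0.5)   # dispatches to the new SparkR DataFrame method
#
# Qualifying the internal RDD helpers keeps them on the base implementation.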
@@ -996,7 +996,7 @@ setMethod("coalesce", if (shuffle || numPartitions > SparkR:::numPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed - start <- as.integer(sample(numPartitions, 1) - 1) + start <- as.integer(base::sample(numPartitions, 1) - 1) lapply(seq_along(part), function(i) { pos <- (start + i) %% numPartitions diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index cae06e6af2bff..531442e8459e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -421,7 +421,7 @@ clearCache <- function(sqlCtx) { #' \dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") +#' df <- read.df(sqlCtx, path, "parquet") #' registerTempTable(df, "table") #' dropTempTable(sqlCtx, "table") #' } @@ -450,10 +450,10 @@ dropTempTable <- function(sqlCtx, tableName) { #'\dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) -#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' df <- read.df(sqlCtx, "path/to/file.json", source = "json") #' } -loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { +read.df <- function(sqlCtx, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path @@ -462,6 +462,13 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { dataFrame(sdf) } +#' @aliases loadDF +#' @export + +loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { + read.df(sqlCtx, path, source, ...) +} + #' Create an external table #' #' Creates an external table based on the dataset in a data source, diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 557128a419f19..6d2bfb1181e5a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -456,19 +456,19 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sample_frac", +setGeneric("sample", function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + standardGeneric("sample") + }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sampleDF", +setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { - standardGeneric("sampleDF") - }) + standardGeneric("sample_frac") + }) #' @rdname saveAsParquetFile #' @export @@ -480,7 +480,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) -#' @rdname saveAsTable +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") }) + +#' @rdname write.df #' @export setGeneric("saveDF", function(df, path, source, mode, ...) 
{ standardGeneric("saveDF") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 99c28830c6237..1109e8fdba3fd 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -209,18 +209,18 @@ test_that("registerTempTable() results in a queryable table and sql() results in }) test_that("insertInto() on a registered table", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", "overwrite") - dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + df <- read.df(sqlCtx, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlCtx, parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") - saveDF(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + df2 <- read.df(sqlCtx, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(sqlCtx, parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") @@ -421,12 +421,12 @@ test_that("distinct() on DataFrames", { expect_true(count(uniques) == 3) }) -test_that("sampleDF on a DataFrame", { +test_that("sample on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) - sampled <- sampleDF(df, FALSE, 1.0) + sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_true(inherits(sampled, "DataFrame")) - sampled2 <- sampleDF(df, FALSE, 0.1) + sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) # Also test sample_frac @@ -491,16 +491,16 @@ test_that("column calculation", { expect_true(count(df2) == 3) }) -test_that("load() from json file", { - df <- loadDF(sqlCtx, jsonPath, "json") +test_that("read.df() from json file", { + df <- read.df(sqlCtx, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) }) -test_that("save() as parquet file", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", mode="overwrite") - df2 <- loadDF(sqlCtx, parquetPath, "parquet") +test_that("write.df() as parquet file", { + df <- read.df(sqlCtx, jsonPath, "json") + write.df(df, parquetPath, "parquet", mode="overwrite") + df2 <- read.df(sqlCtx, parquetPath, "parquet") expect_true(inherits(df2, "DataFrame")) expect_true(count(df2) == 3) }) @@ -670,7 +670,7 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlCtx, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) @@ -712,9 +712,9 @@ test_that("mutate() and rename()", { expect_true(columns(newDF2)[1] == "newerAge") }) -test_that("saveDF() on DataFrame and works with parquetFile", { +test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") + write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlCtx, parquetPath) expect_true(inherits(parquetDF, "DataFrame")) expect_equal(count(df), count(parquetDF)) @@ -722,9 +722,9 @@ test_that("saveDF() on DataFrame and works with 
parquetFile", { test_that("parquetFile works with multiple input paths", { df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") + write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - saveDF(df, parquetPath2, "parquet", mode="overwrite") + write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) expect_true(inherits(parquetDF, "DataFrame")) expect_true(count(parquetDF) == count(df)*2) From 98195c3031fe60683bb25840f135458d5d0e52c5 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Tue, 12 May 2015 23:55:44 -0700 Subject: [PATCH 003/109] [SPARK-7526] [SPARKR] Specify ip of RBackend, MonitorServer and RRDD Socket server These R process only used to communicate with JVM process on local, so binding to localhost is more reasonable then wildcard ip. Author: linweizhong Closes #6053 from Sephiroth-Lin/spark-7526 and squashes the following commits: 5303af7 [linweizhong] bind to localhost rather than wildcard ip --- core/src/main/scala/org/apache/spark/api/r/RBackend.scala | 6 +++--- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 3a2c94bd9d875..0a91977928cee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetSocketAddress, ServerSocket} +import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -65,7 +65,7 @@ private[spark] class RBackend { } }) - channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() } @@ -101,7 +101,7 @@ private[spark] object RBackend extends Logging { try { // bind to random port val boundPort = sparkRBackend.init() - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // tell the R process via temporary file diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 6fea5e1144f2f..06247f7e8b78c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io._ -import java.net.ServerSocket +import java.net.{InetAddress, ServerSocket} import java.util.{Map => JMap} import scala.collection.JavaConversions._ @@ -55,7 +55,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( val parentIterator = firstParent[T].iterator(partition, context) // we expect two connections - val serverSocket = new ServerSocket(0, 2) + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // The stdout/stderr is shared by multiple tasks, because we use one daemon @@ -414,7 +414,7 @@ private[r] object RRDD { synchronized { if (daemonChannel == null) { // we expect 
one connections - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort errThread = createRProcess(rLibDir, daemonPort, "daemon.R") // the socket used to send out the input of task From 50c72708015fba15d0e78946f1f4ec262776bc38 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Wed, 13 May 2015 09:43:40 +0100 Subject: [PATCH 004/109] [SPARK-6568] spark-shell.cmd --jars option does not accept the jar that has space in its path escape spaces in the arguments. Author: Masayoshi TSUZUKI Author: Kousuke Saruta Closes #5447 from tsudukim/feature/SPARK-6568-2 and squashes the following commits: 3f9a188 [Masayoshi TSUZUKI] modified some errors. ed46047 [Masayoshi TSUZUKI] avoid scalastyle errors. 1784239 [Masayoshi TSUZUKI] removed Utils.formatPath. e03f289 [Masayoshi TSUZUKI] removed testWindows from Utils.resolveURI and Utils.resolveURIs. replaced SystemUtils.IS_OS_WINDOWS to Utils.isWindows. removed Utils.formatPath from PythonRunner.scala. 84c33d0 [Masayoshi TSUZUKI] - use resolveURI in nonLocalPaths - run tests for Windows path only on Windows 016128d [Masayoshi TSUZUKI] fixed to use File.toURI() 2c62e3b [Masayoshi TSUZUKI] Merge pull request #1 from sarutak/SPARK-6568-2 7019a8a [Masayoshi TSUZUKI] Merge branch 'master' of https://github.com/apache/spark into feature/SPARK-6568-2 45946ee [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-6568-2 10f1c73 [Kousuke Saruta] Added a comment 93c3c40 [Kousuke Saruta] Merge branch 'classpath-handling-fix' of github.com:sarutak/spark into SPARK-6568-2 649da82 [Kousuke Saruta] Fix classpath handling c7ba6a7 [Masayoshi TSUZUKI] [SPARK-6568] spark-shell.cmd --jars option does not accept the jar that has space in its path --- .../apache/spark/deploy/PythonRunner.scala | 23 ++++--- .../scala/org/apache/spark/util/Utils.scala | 43 ++++-------- .../spark/deploy/PythonRunnerSuite.scala | 31 +++++---- .../org/apache/spark/util/UtilsSuite.scala | 67 ++++++++++++------- .../org/apache/spark/repl/SparkILoop.scala | 5 +- 5 files changed, 89 insertions(+), 80 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 53e18c4bcec23..c2ed43a5397d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,9 +18,11 @@ package org.apache.spark.deploy import java.net.URI +import java.io.File import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ +import scala.util.Try import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -81,16 +83,13 @@ object PythonRunner { throw new IllegalArgumentException("Launching Python applications through " + s"spark-submit is currently only supported for local files: $path") } - val windows = Utils.isWindows || testWindows - var formattedPath = if (windows) Utils.formatWindowsPath(path) else path - - // Strip the URI scheme from the path - formattedPath = - new URI(formattedPath).getScheme match { - case null => formattedPath - case Utils.windowsDrive(d) if windows => formattedPath - case _ => new URI(formattedPath).getPath - } + // get path when scheme is file. 
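      // Illustrative sketch, not part of the patch: how java.net.URI behaves for the
      // kinds of paths this method has to handle (hypothetical examples):
      //   new URI("file:/C:/a/b/app.py")   // scheme "file", getPath = "/C:/a/b/app.py"
      //   new URI("app.py")                // no scheme, so getScheme == null
      //   new URI("C:\\a b\\app.py")       // URISyntaxException: backslashes and raw
      //                                    // spaces are not legal in a URI string
      // The Try(...).getOrElse(new File(path).toURI) below covers the last case by
      // falling back to File.toURI, which percent-encodes the path; the leading "/"
      // left on Windows drive paths is stripped a few lines further down.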
+ val uri = Try(new URI(path)).getOrElse(new File(path).toURI) + var formattedPath = uri.getScheme match { + case null => path + case "file" | "local" => uri.getPath + case _ => null + } // Guard against malformed paths potentially throwing NPE if (formattedPath == null) { @@ -99,7 +98,9 @@ object PythonRunner { // In Windows, the drive should not be prefixed with "/" // For instance, python does not understand "/C:/path/to/sheep.py" - formattedPath = if (windows) formattedPath.stripPrefix("/") else formattedPath + if (Utils.isWindows && formattedPath.matches("/[a-zA-Z]:/.*")) { + formattedPath = formattedPath.stripPrefix("/") + } formattedPath } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index be4db02ab86d0..48843b4ae57c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1704,11 +1704,6 @@ private[spark] object Utils extends Logging { */ val windowsDrive = "([a-zA-Z])".r - /** - * Format a Windows path such that it can be safely passed to a URI. - */ - def formatWindowsPath(path: String): String = path.replace("\\", "/") - /** * Indicates whether Spark is currently running unit tests. */ @@ -1806,37 +1801,24 @@ private[spark] object Utils extends Logging { * If the supplied path does not contain a scheme, or is a relative path, it will be * converted into an absolute path with a file:// scheme. */ - def resolveURI(path: String, testWindows: Boolean = false): URI = { - - // In Windows, the file separator is a backslash, but this is inconsistent with the URI format - val windows = isWindows || testWindows - val formattedPath = if (windows) formatWindowsPath(path) else path - - val uri = new URI(formattedPath) - if (uri.getPath == null) { - throw new IllegalArgumentException(s"Given path is malformed: $uri") - } - - Option(uri.getScheme) match { - case Some(windowsDrive(d)) if windows => - new URI("file:/" + uri.toString.stripPrefix("/")) - case None => - // Preserve fragments for HDFS file name substitution (denoted by "#") - // For instance, in "abc.py#xyz.py", "xyz.py" is the name observed by the application - val fragment = uri.getFragment - val part = new File(uri.getPath).toURI - new URI(part.getScheme, part.getPath, fragment) - case Some(other) => - uri + def resolveURI(path: String): URI = { + try { + val uri = new URI(path) + if (uri.getScheme() != null) { + return uri + } + } catch { + case e: URISyntaxException => } + new File(path).getAbsoluteFile().toURI() } /** Resolve a comma-separated list of paths. 
*/ - def resolveURIs(paths: String, testWindows: Boolean = false): String = { + def resolveURIs(paths: String): String = { if (paths == null || paths.trim.isEmpty) { "" } else { - paths.split(",").map { p => Utils.resolveURI(p, testWindows) }.mkString(",") + paths.split(",").map { p => Utils.resolveURI(p) }.mkString(",") } } @@ -1847,8 +1829,7 @@ private[spark] object Utils extends Logging { Array.empty } else { paths.split(",").filter { p => - val formattedPath = if (windows) formatWindowsPath(p) else p - val uri = new URI(formattedPath) + val uri = resolveURI(p) Option(uri.getScheme).getOrElse("file") match { case windowsDrive(d) if windows => false case "local" | "file" => false diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala index bb6251fb4bfbe..80f2cc02516fe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy import org.scalatest.FunSuite +import org.apache.spark.util.Utils + class PythonRunnerSuite extends FunSuite { // Test formatting a single path to be added to the PYTHONPATH @@ -28,10 +30,14 @@ class PythonRunnerSuite extends FunSuite { assert(PythonRunner.formatPath("file:///spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:/spark.py") === "/spark.py") assert(PythonRunner.formatPath("local:///spark.py") === "/spark.py") - assert(PythonRunner.formatPath("C:/a/b/spark.py", testWindows = true) === "C:/a/b/spark.py") - assert(PythonRunner.formatPath("/C:/a/b/spark.py", testWindows = true) === "C:/a/b/spark.py") - assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) === - "C:/a/b/spark.py") + if (Utils.isWindows) { + assert(PythonRunner.formatPath("file:/C:/a/b/spark.py", testWindows = true) === + "C:/a/b/spark.py") + assert(PythonRunner.formatPath("C:\\a\\b\\spark.py", testWindows = true) === + "C:/a/b/spark.py") + assert(PythonRunner.formatPath("C:\\a b\\spark.py", testWindows = true) === + "C:/a b/spark.py") + } intercept[IllegalArgumentException] { PythonRunner.formatPath("one:two") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:s3:xtremeFS") } intercept[IllegalArgumentException] { PythonRunner.formatPath("hdfs:/path/to/some.py") } @@ -45,14 +51,15 @@ class PythonRunnerSuite extends FunSuite { Array("/app.py", "/spark.py")) assert(PythonRunner.formatPaths("me.py,file:/you.py,local:/we.py") === Array("me.py", "/you.py", "/we.py")) - assert(PythonRunner.formatPaths("C:/a/b/spark.py", testWindows = true) === - Array("C:/a/b/spark.py")) - assert(PythonRunner.formatPaths("/C:/a/b/spark.py", testWindows = true) === - Array("C:/a/b/spark.py")) - assert(PythonRunner.formatPaths("C:/free.py,pie.py", testWindows = true) === - Array("C:/free.py", "pie.py")) - assert(PythonRunner.formatPaths("lovely.py,C:/free.py,file:/d:/fry.py", testWindows = true) === - Array("lovely.py", "C:/free.py", "d:/fry.py")) + if (Utils.isWindows) { + assert(PythonRunner.formatPaths("C:\\a\\b\\spark.py", testWindows = true) === + Array("C:/a/b/spark.py")) + assert(PythonRunner.formatPaths("C:\\free.py,pie.py", testWindows = true) === + Array("C:/free.py", "pie.py")) + assert(PythonRunner.formatPaths("lovely.py,C:\\free.py,file:/d:/fry.py", + testWindows = true) === + Array("lovely.py", "C:/free.py", "d:/fry.py")) + } intercept[IllegalArgumentException] { PythonRunner.formatPaths("one:two,three") } 
intercept[IllegalArgumentException] { PythonRunner.formatPaths("two,three,four:five:six") } intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 651ead6ff1de2..61152c29a681f 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -367,51 +367,58 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { } test("resolveURI") { - def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = { + def assertResolves(before: String, after: String): Unit = { // This should test only single paths assume(before.split(",").length === 1) // Repeated invocations of resolveURI should yield the same result - def resolve(uri: String): String = Utils.resolveURI(uri, testWindows).toString + def resolve(uri: String): String = Utils.resolveURI(uri).toString assert(resolve(after) === after) assert(resolve(resolve(after)) === after) assert(resolve(resolve(resolve(after))) === after) // Also test resolveURIs with single paths - assert(new URI(Utils.resolveURIs(before, testWindows)) === new URI(after)) - assert(new URI(Utils.resolveURIs(after, testWindows)) === new URI(after)) + assert(new URI(Utils.resolveURIs(before)) === new URI(after)) + assert(new URI(Utils.resolveURIs(after)) === new URI(after)) } - val cwd = System.getProperty("user.dir") + val rawCwd = System.getProperty("user.dir") + val cwd = if (Utils.isWindows) s"/$rawCwd".replace("\\", "/") else rawCwd assertResolves("hdfs:/root/spark.jar", "hdfs:/root/spark.jar") assertResolves("hdfs:///root/spark.jar#app.jar", "hdfs:/root/spark.jar#app.jar") assertResolves("spark.jar", s"file:$cwd/spark.jar") - assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar#app.jar") - assertResolves("C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) - assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt", testWindows = true) - assertResolves("file:/C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) - assertResolves("file:///C:/path/to/file.txt", "file:/C:/path/to/file.txt", testWindows = true) - assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt", testWindows = true) - intercept[IllegalArgumentException] { Utils.resolveURI("file:foo") } - intercept[IllegalArgumentException] { Utils.resolveURI("file:foo:baby") } + assertResolves("spark.jar#app.jar", s"file:$cwd/spark.jar%23app.jar") + assertResolves("path to/file.txt", s"file:$cwd/path%20to/file.txt") + if (Utils.isWindows) { + assertResolves("C:\\path\\to\\file.txt", "file:/C:/path/to/file.txt") + assertResolves("C:\\path to\\file.txt", "file:/C:/path%20to/file.txt") + } + assertResolves("file:/C:/path/to/file.txt", "file:/C:/path/to/file.txt") + assertResolves("file:///C:/path/to/file.txt", "file:/C:/path/to/file.txt") + assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt") + assertResolves("file:foo", s"file:foo") + assertResolves("file:foo:baby", s"file:foo:baby") } test("resolveURIs with multiple paths") { - def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = { + def assertResolves(before: String, after: String): Unit = { assume(before.split(",").length > 1) - assert(Utils.resolveURIs(before, testWindows) === after) - assert(Utils.resolveURIs(after, testWindows) === after) + 
assert(Utils.resolveURIs(before) === after) + assert(Utils.resolveURIs(after) === after) // Repeated invocations of resolveURIs should yield the same result - def resolve(uri: String): String = Utils.resolveURIs(uri, testWindows) + def resolve(uri: String): String = Utils.resolveURIs(uri) assert(resolve(after) === after) assert(resolve(resolve(after)) === after) assert(resolve(resolve(resolve(after))) === after) } - val cwd = System.getProperty("user.dir") + val rawCwd = System.getProperty("user.dir") + val cwd = if (Utils.isWindows) s"/$rawCwd".replace("\\", "/") else rawCwd assertResolves("jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2") assertResolves("file:/jar1,file:/jar2", "file:/jar1,file:/jar2") assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") - assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5") - assertResolves("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi", testWindows = true) + assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5,path to/jar6", + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4%23jar5,file:$cwd/path%20to/jar6") + if (Utils.isWindows) { + assertResolves("""hdfs:/jar1,file:/jar2,jar3,C:\pi.py#py.pi,C:\path to\jar4""", + s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py%23py.pi,file:/C:/path%20to/jar4") + } } test("nonLocalPaths") { @@ -426,6 +433,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { assert(Utils.nonLocalPaths("local:/spark.jar,file:/smart.jar,family.py") === Array.empty) assert(Utils.nonLocalPaths("hdfs:/spark.jar,s3:/smart.jar") === Array("hdfs:/spark.jar", "s3:/smart.jar")) + assert(Utils.nonLocalPaths("hdfs:/spark.jar,path to/a.jar,s3:/smart.jar") === + Array("hdfs:/spark.jar", "s3:/smart.jar")) assert(Utils.nonLocalPaths("hdfs:/spark.jar,s3:/smart.jar,local.py,file:/hello/pi.py") === Array("hdfs:/spark.jar", "s3:/smart.jar")) assert(Utils.nonLocalPaths("local.py,hdfs:/spark.jar,file:/hello/pi.py,s3:/smart.jar") === @@ -547,7 +556,12 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { val targetDir = new File(tempDir, "target-dir") Files.write("some text", sourceFile, UTF_8) - val path = new Path("file://" + sourceDir.getAbsolutePath) + val path = + if (Utils.isWindows) { + new Path("file:/" + sourceDir.getAbsolutePath.replace("\\", "/")) + } else { + new Path("file://" + sourceDir.getAbsolutePath) + } val conf = new Configuration() val fs = Utils.getHadoopFileSystem(path.toString, conf) @@ -567,7 +581,12 @@ class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { val destInnerFile = new File(destInnerDir, sourceFile.getName) assert(destInnerFile.isFile()) - val filePath = new Path("file://" + sourceFile.getAbsolutePath) + val filePath = + if (Utils.isWindows) { + new Path("file:/" + sourceFile.getAbsolutePath.replace("\\", "/")) + } else { + new Path("file://" + sourceFile.getAbsolutePath) + } val testFileDir = new File(tempDir, "test-filename") val testFileName = "testFName" val testFilefs = Utils.getHadoopFileSystem(filePath.toString, conf) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 488f3a9f33256..2b235525250c2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -206,7 +206,8 
@@ class SparkILoop( // e.g. file:/C:/my/path.jar -> C:/my/path.jar SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") } } else { - SparkILoop.getAddedJars + // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). + SparkILoop.getAddedJars.map { jar => new URI(jar).getPath } } // work around for Scala bug val totalClassPath = addedJars.foldLeft( @@ -1109,7 +1110,7 @@ object SparkILoop extends Logging { if (settings.classpath.isDefault) settings.classpath.value = sys.props("java.class.path") - getAddedJars.foreach(settings.classpath.append(_)) + getAddedJars.map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) repl process settings } From 10c546e9d42a0f3fbf45c919e74f62c548ca8347 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 May 2015 07:35:55 -0700 Subject: [PATCH 005/109] [SPARK-7599] [SQL] Don't restrict customized output committers to be subclasses of FileOutputCommitter Author: Cheng Lian Closes #6118 from liancheng/spark-7599 and squashes the following commits: 31e1bd6 [Cheng Lian] Don't restrict customized output committers to be subclasses of FileOutputCommitter --- .../apache/spark/sql/sources/commands.scala | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 8372d2c34acc7..fe8be5b7feeb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -244,7 +244,7 @@ private[sql] abstract class BaseWriterContainer( @transient private val jobContext: JobContext = job // The following fields are initialized and used on both driver and executor side. - @transient protected var outputCommitter: FileOutputCommitter = _ + @transient protected var outputCommitter: OutputCommitter = _ @transient private var jobId: JobID = _ @transient private var taskId: TaskID = _ @transient private var taskAttemptId: TaskAttemptID = _ @@ -282,14 +282,18 @@ private[sql] abstract class BaseWriterContainer( initWriters() } - private def newOutputCommitter(context: TaskAttemptContext): FileOutputCommitter = { - outputFormatClass.newInstance().getOutputCommitter(context) match { - case f: FileOutputCommitter => f - case f => sys.error( - s"FileOutputCommitter or its subclass is expected, but got a ${f.getClass.getName}.") + protected def getWorkPath: String = { + outputCommitter match { + // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. 
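        // Illustrative sketch, not part of the patch: with outputCommitter now typed
        // as OutputCommitter, a committer that does not extend FileOutputCommitter is
        // accepted instead of hitting the old sys.error. A hypothetical example:
        //
        //   class DirectOutputCommitter extends OutputCommitter {
        //     override def setupJob(jobContext: JobContext): Unit = ()
        //     override def setupTask(taskContext: TaskAttemptContext): Unit = ()
        //     override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = false
        //     override def commitTask(taskContext: TaskAttemptContext): Unit = ()
        //     override def abortTask(taskContext: TaskAttemptContext): Unit = ()
        //   }
        //
        // For such a committer the match below falls through to `outputPath`, so output
        // writers are initialized against the final output directory directly.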
+ case f: FileOutputCommitter => f.getWorkPath.toString + case _ => outputPath } } + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + outputFormatClass.newInstance().getOutputCommitter(context) + } + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) @@ -339,7 +343,7 @@ private[sql] class DefaultWriterContainer( override protected def initWriters(): Unit = { writer = outputWriterClass.newInstance() - writer.init(outputCommitter.getWorkPath.toString, dataSchema, taskAttemptContext) + writer.init(getWorkPath, dataSchema, taskAttemptContext) } override def outputWriterForRow(row: Row): OutputWriter = writer @@ -381,7 +385,7 @@ private[sql] class DynamicPartitionWriterContainer( }.mkString outputWriters.getOrElseUpdate(partitionPath, { - val path = new Path(outputCommitter.getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR)) + val path = new Path(getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR)) val writer = outputWriterClass.newInstance() writer.init(path.toString, dataSchema, taskAttemptContext) writer From b061bd517a3dc26e7f37a334f49c3465d98334c6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 13 May 2015 23:36:19 +0800 Subject: [PATCH 006/109] [SQL] In InsertIntoFSBasedRelation.insert, log cause before abort job/task. We need to add a log entry before calling `abortTask`/`abortJob`. Otherwise, an exception from `abortTask`/`abortJob` will shadow the real cause. cc liancheng Author: Yin Huai Closes #6105 from yhuai/logCause and squashes the following commits: 8dfe0d8 [Yin Huai] Log cause. --- .../src/main/scala/org/apache/spark/sql/sources/commands.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fe8be5b7feeb9..a294297677d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -121,6 +121,7 @@ private[sql] case class InsertIntoFSBasedRelation( writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => + logError("Aborting job.", cause) writerContainer.abortJob() throw new SparkException("Job aborted.", cause) } @@ -143,6 +144,7 @@ private[sql] case class InsertIntoFSBasedRelation( } writerContainer.commitTask() } catch { case cause: Throwable => + logError("Aborting task.", cause) writerContainer.abortTask() throw new SparkException("Task failed while writing rows.", cause) } From aa6ba3f2166edcc8bcda3abc70482fa8605e83b7 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 May 2015 23:40:13 +0800 Subject: [PATCH 007/109] [MINOR] [SQL] Removes debugging println Author: Cheng Lian Closes #6123 from liancheng/remove-println and squashes the following commits: 03356b6 [Cheng Lian] Removes debugging println --- .../org/apache/spark/sql/sources/FSBasedRelationSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala index 415b1cd168848..e8b48a0db1c79 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala @@ -509,8 +509,6 @@ class FSBasedRelationSuite 
extends QueryTest with ParquetTest { path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } - println(df.queryExecution) - val actualPaths = df.queryExecution.analyzed.collectFirst { case LogicalRelation(relation: FSBasedRelation) => relation.paths.toSet From 0da254fb2903c01e059fa7d0dc81df5740312b35 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 14 May 2015 00:14:59 +0800 Subject: [PATCH 008/109] [SPARK-6734] [SQL] Add UDTF.close support in Generate Some third-party UDTF extensions generate additional rows in the "GenericUDTF.close()" method, which is supported / documented by Hive. https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF However, Spark SQL ignores the "GenericUDTF.close()", and it causes bug while porting job from Hive to Spark SQL. Author: Cheng Hao Closes #5383 from chenghao-intel/udtf_close and squashes the following commits: 98b4e4b [Cheng Hao] Support UDTF.close --- .../sql/catalyst/expressions/generators.scala | 6 +++ .../apache/spark/sql/execution/Generate.scala | 38 +++++++++++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 18 +++++++-- sql/hive/src/test/resources/TestUDTF.jar | Bin 0 -> 1328 bytes ...l Views-0-ac5c96224a534f07b49462ad76620678 | 2 + ... SELECT-0-517f834fef35b896ec64399f42b2a151 | 2 + .../sql/hive/execution/HiveQuerySuite.scala | 21 ++++++++++ 7 files changed, 74 insertions(+), 13 deletions(-) create mode 100644 sql/hive/src/test/resources/TestUDTF.jar create mode 100644 sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 create mode 100644 sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9a6cb048af5ad..747a47bdde953 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -56,6 +56,12 @@ abstract class Generator extends Expression { /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] + + /** + * Notifies that there are no more rows to process, clean up code, and additional + * rows can be made here. + */ + def terminate(): TraversableOnce[Row] = Nil } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 08d9079335132..dd02c1f4573bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -21,6 +21,18 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +/** + * For lazy computing, be sure the generator.terminate() called in the very last + * TODO reusing the CompletionIterator? 
+ */ +private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row]) + extends Iterator[Row] { + + lazy val results = func().toIterator + override def hasNext: Boolean = results.hasNext + override def next(): Row = results.next() +} + /** * :: DeveloperApi :: * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the @@ -47,27 +59,33 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[Row] = { + // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) - // Used to produce rows with no matches when outer = true. - val outerProjection = - newProjection(child.output ++ nullValues, child.output) - - val joinProjection = newProjection(output, output) + val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow - iter.flatMap {row => + iter.flatMap { row => + // we should always set the left (child output) + joinedRow.withLeft(row) val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { - outerProjection(row) :: Nil + joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinProjection(joinedRow(row, or))) + outputRows.map(or => joinedRow.withRight(or)) } + } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + // we leave the left side as the last element of its child output + // keep it the same as Hive does + joinedRow.withRight(row) } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) + child.execute().mapPartitions { iter => + iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate()) + } } } } + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fd0b6f058595d..bc6b3a2d58c38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -483,7 +483,11 @@ private[hive] case class HiveGenericUdtf( extends Generator with HiveInspectors { @transient - protected lazy val function: GenericUDTF = funcWrapper.createFunction() + protected lazy val function: GenericUDTF = { + val fun: GenericUDTF = funcWrapper.createFunction() + fun.setCollector(collector) + fun + } @transient protected lazy val inputInspectors = children.map(toInspector) @@ -494,6 +498,9 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) + @transient + protected lazy val collector = new UDTFCollector + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } @@ -502,8 +509,7 @@ private[hive] case class HiveGenericUdtf( outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - val collector = new UDTFCollector - function.setCollector(collector) + function.process(wrap(inputProjection(input), inputInspectors, udtInput)) collector.collectRows() } @@ -525,6 +531,12 @@ private[hive] case class HiveGenericUdtf( } } + override def terminate(): TraversableOnce[Row] = { + outputInspector // Make sure initialized. 
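    // Illustrative sketch, not part of the patch: a GenericUDTF may emit extra rows
    // from close(), which is why terminate() drains the collector again. A
    // hypothetical row-counting UDTF (initialize omitted):
    //
    //   class CountRows extends GenericUDTF {
    //     private var count = 0L
    //     override def process(args: Array[AnyRef]): Unit = { count += 1 }
    //     override def close(): Unit = forward(Array[AnyRef](java.lang.Long.valueOf(count)))
    //   }
    //
    // Rows forwarded inside close() only reach the collector here, after all input
    // rows have been seen, so collectRows() below picks them up.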
+ function.close() + collector.collectRows() + } + override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } diff --git a/sql/hive/src/test/resources/TestUDTF.jar b/sql/hive/src/test/resources/TestUDTF.jar new file mode 100644 index 0000000000000000000000000000000000000000..514f2d5d26fd358ad5647e0e75edb8ce77b69e30 GIT binary patch literal 1328 zcmWIWW@Zs#;Nak3xSr7K#()Gk8CV#6T|*poJ^kGD|D9rBU}gyLX6FE@V1g9+ zP&U6PT|coPF*zeuzqlZ=C|kd{Fh@Tlvn*9VwIVgSv?Mb>Pv1Q?FSRH$In*V@%{jj` zuf#|%IVZ8WcxlK>KjA=;e|ow+J<2CuHJwts5Oyj+z|m2lF^e}kXqra-9v}VG4QH-I z3e+&hyZ7o`e8K!9;8^kPO%YMTk5r1!#h&{#_wB!5f8VM%C>{yDve5geZOQcoVzyjW zs{AWm4sj`IiIjvs-uI-|v|~|xYaioEfnb9b2_6TT+q@2c)9qham^b@gU-(}S?-_u{zB$yU{&~h%g<;U5!%kr`i_rG&( z1=D`Ixp}xt-d0_A?wzF4>zUi0YDp~H8)9C0?2p5AdFR)~+a_PDoF*OZbpMg&?Q^f9 zpRKzP`*6vT1$oc^%$hb^HO-n;rgQS&h#!-Fd7l5syZ%O3Q#+5>wT-+Fck$>h6k2#X zcLjSfwFen_YXhGm z^=`M1UD-x;p7>=k`xz!GhN&ekd9&Q?Yf0SWJ!;Oue2X$)dsxC;)*?8E^F7HIkJ4tmi@{)@8^X5ZB1KR^!8JazLNx?@F#6Qh=oqF?F^3mU?`(NA31m#>^_H^wjyAyleIZG~N zN3cs-gj`tdxtaav#WjMwmn?hLF23F(t;XrL$Fhh;j8kUuBZfwg5A03sEB^S1^O$#B zQ*zkSJH>jz!pVO5lRCTRbb9WZtQMslA1J9HGa0HDESi2z@s87s ze>SVae&s**JNU+?Q1kX**;A{PL(Yr;aEf<6t9WJW3-OqVd-SV#-d sql("USE default") From bec938f777a2e18757c7d04504d86a5342e2b49e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 13 May 2015 10:01:26 -0700 Subject: [PATCH 009/109] [SPARK-7589] [STREAMING] [WEBUI] Make "Input Rate" in the Streaming page consistent with other pages This PR makes "Input Rate" in the Streaming page consistent with Job and Stage pages. ![screen shot 2015-05-12 at 5 03 35 pm](https://cloud.githubusercontent.com/assets/1000778/7601444/f943f8ac-f8ca-11e4-8280-a715d814f434.png) ![screen shot 2015-05-12 at 5 07 25 pm](https://cloud.githubusercontent.com/assets/1000778/7601445/f9571c0c-f8ca-11e4-9b12-9317cb55c002.png) Author: zsxwing Closes #6102 from zsxwing/SPARK-7589 and squashes the following commits: 2745225 [zsxwing] Make "Input Rate" in the Streaming page consistent with other pages --- .../apache/spark/ui/static/streaming-page.css | 4 +++ .../apache/spark/ui/static/streaming-page.js | 18 ++++++++++--- .../spark/streaming/ui/StreamingPage.scala | 27 +++++++++---------- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css index 5da9d631ad124..19abe889ad3c1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css +++ b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css @@ -56,3 +56,7 @@ .histogram { width: auto; } + +span.expand-input-rate { + cursor: pointer; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js index a4e03b156f13e..22b186873e990 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js +++ b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js @@ -266,9 +266,19 @@ $(function() { } } - if (getParameterFromURL("show-streams-detail") == "true") { - // Show the details for all InputDStream - $('#inputs-table').toggle('collapsed'); - $('#triangle').html('▼'); + var status = getParameterFromURL("show-streams-detail") == "true"; + + $("span.expand-input-rate").click(function() { + status = !status; + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + 
$(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); + window.history.pushState('', document.title, window.location.pathname + '?show-streams-detail=' + status); + }); + + if (status) { + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); } }); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index ff0f2b18dc321..efce8c58fb962 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -244,17 +244,6 @@ private[ui] class StreamingPage(parent: StreamingTab) val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) val minEventRate = 0L - // JavaScript to show/hide the InputDStreams sub table. - val triangleJs = - s"""$$('#inputs-table').toggle('collapsed'); - |var status = false; - |if ($$(this).html() == '$BLACK_RIGHT_TRIANGLE_HTML') { - |$$(this).html('$BLACK_DOWN_TRIANGLE_HTML');status = true;} - |else {$$(this).html('$BLACK_RIGHT_TRIANGLE_HTML');status = false;} - |window.history.pushState('', - | document.title, window.location.pathname + '?show-streams-detail=' + status);""" - .stripMargin.replaceAll("\\n", "") // it must be only one single line - val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) val jsCollector = new JsCollector @@ -326,10 +315,18 @@ private[ui] class StreamingPage(parent: StreamingTab)
- {if (hasStream) { - {Unparsed(BLACK_RIGHT_TRIANGLE_HTML)} - }} - Input Rate + { + if (hasStream) { + + + + Input Rate + + + } else { + Input Rate + } + }
Avg: {eventRateForAllStreams.formattedAvg} events/sec
From 7ff16e8abef9fbf4a4855e23c256b22e62e560a6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 13 May 2015 11:04:10 -0700 Subject: [PATCH 010/109] [SPARK-7567] [SQL] Migrating Parquet data source to FSBasedRelation This PR migrates Parquet data source to the newly introduced `FSBasedRelation`. `FSBasedParquetRelation` is created to replace `ParquetRelation2`. Major differences are: 1. Partition discovery code has been factored out to `FSBasedRelation` 1. `AppendingParquetOutputFormat` is not used now. Instead, an anonymous subclass of `ParquetOutputFormat` is used to handle appending and writing dynamic partitions 1. When scanning partitioned tables, `FSBasedParquetRelation.buildScan` only builds an `RDD[Row]` for a single selected partition 1. `FSBasedParquetRelation` doesn't rely on Catalyst expressions for filter push down, thus it doesn't extend `CatalystScan` anymore After migrating `JSONRelation` (which extends `CatalystScan`), we can remove `CatalystScan`. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/6090) Author: Cheng Lian Closes #6090 from liancheng/parquet-migration and squashes the following commits: 6063f87 [Cheng Lian] Casts to OutputCommitter rather than FileOutputCommtter bfd1cf0 [Cheng Lian] Fixes compilation error introduced while rebasing f9ea56e [Cheng Lian] Adds ParquetRelation2 related classes to MiMa check whitelist 261d8c1 [Cheng Lian] Minor bug fix and more tests db65660 [Cheng Lian] Migrates Parquet data source to FSBasedRelation --- project/MimaExcludes.scala | 6 + .../org/apache/spark/sql/SQLContext.scala | 8 +- .../spark/sql/parquet/ParquetFilters.scala | 278 +++--- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../spark/sql/parquet/fsBasedParquet.scala | 565 ++++++++++++ .../apache/spark/sql/parquet/newParquet.scala | 840 ------------------ .../apache/spark/sql/sources/commands.scala | 19 +- .../sql/parquet/ParquetFilterSuite.scala | 6 +- .../spark/sql/parquet/ParquetIOSuite.scala | 6 +- .../ParquetPartitionDiscoverySuite.scala | 10 +- .../sql/parquet/ParquetSchemaSuite.scala | 12 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 25 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 15 +- .../sql/hive/execution/SQLQuerySuite.scala | 16 +- .../apache/spark/sql/hive/parquetSuites.scala | 35 +- ...uite.scala => fsBasedRelationSuites.scala} | 173 ++-- 16 files changed, 926 insertions(+), 1090 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala rename sql/hive/src/test/scala/org/apache/spark/sql/sources/{FSBasedRelationSuite.scala => fsBasedRelationSuites.scala} (83%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a47e29e2ef365..f31f0e554eee9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -111,6 +111,12 @@ object MimaExcludes { "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), // These test support classes were moved out of src/main and into src/test: ProblemFilters.exclude[MissingClassProblem]( 
"org.apache.spark.sql.parquet.ParquetTestData"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 975498c11fa23..0a148c7cd2d3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -27,9 +27,11 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import com.google.common.reflect.TypeToken +import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -42,6 +44,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, e import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ +import org.apache.spark.sql.parquet.FSBasedParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -641,7 +644,10 @@ class SQLContext(@transient val sparkContext: SparkContext) if (paths.isEmpty) { emptyDataFrame } else if (conf.parquetUseDataSourceApi) { - baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this)) + val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray + baseRelationToDataFrame( + new FSBasedParquetRelation( + globbedPaths.map(_.toString), None, None, Map.empty[String, String])(this)) } else { DataFrame(this, parquet.ParquetRelation( paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 5eb1c6abc2432..f0f4e7d147e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -29,128 +29,184 @@ import parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.sources import org.apache.spark.sql.types._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get) + filterExpressions.flatMap { filter => + createFilter(filter) + }.reduceOption(FilterApi.and).map(FilterCompat.get) } - def createFilter(predicate: Expression): Option[FilterPredicate] = { - val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - // Binary.fromString and 
Binary.fromByteArray don't accept null values - case StringType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case BooleanType => + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + case IntegerType => + (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } + // Binary.fromString and Binary.fromByteArray don't accept null values + case StringType => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + case BinaryType => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + } - val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case BooleanType => + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + case IntegerType => + (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case 
FloatType => + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + case BinaryType => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + } - val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } - val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), 
Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } - val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } + + private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `SimplifyFilters` rule for details. 
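+ //
+ // Illustrative example, assuming `schema` contains an IntegerType column "age" and a
+ // StringType column "name" (hypothetical names): a data source filter such as
+ //   sources.And(sources.GreaterThan("age", 21), sources.IsNotNull("name"))
+ // is converted by the match below to roughly
+ //   FilterApi.and(FilterApi.gt(intColumn("age"), 21: Integer),
+ //                 FilterApi.notEq(binaryColumn("name"), null))
+ // whereas a predicate over a column of an unsupported type yields None and is simply
+ // not pushed down.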
+ predicate match { + case sources.IsNull(name) => + makeEq.lift(dataTypeOf(name)).map(_(name, null)) + case sources.IsNotNull(name) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) + + case sources.EqualTo(name, value) => + makeEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.Not(sources.EqualTo(name, value)) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.LessThan(name, value) => + makeLt.lift(dataTypeOf(name)).map(_(name, value)) + case sources.LessThanOrEqual(name, value) => + makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.GreaterThan(name, value) => + makeGt.lift(dataTypeOf(name)).map(_(name, value)) + case sources.GreaterThanOrEqual(name, value) => + makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.And(lhs, rhs) => + (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) + + case sources.Or(lhs, rhs) => + for { + lhsFilter <- createFilter(schema, lhs) + rhsFilter <- createFilter(schema, rhs) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilter(schema, pred).map(FilterApi.not) + + case _ => None } + } + /** + * Converts Catalyst predicate expressions to Parquet filter predicates. + * + * @todo This can be removed once we get rid of the old Parquet support. + */ + def createFilter(predicate: Expression): Option[FilterPredicate] = { // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -170,7 +226,7 @@ private[sql] object ParquetFilters { makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => makeEq.lift(dataType).map(_(name, value)) - + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => @@ -192,7 +248,7 @@ private[sql] object ParquetFilters { case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => @@ -201,7 +257,7 @@ private[sql] object ParquetFilters { case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => @@ -210,7 +266,7 @@ private[sql] object ParquetFilters { case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), 
NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 75ac52d4a98ff..90950f924a054 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -674,7 +674,7 @@ private[parquet] object FileSystemHelper { def findMaxTaskId(pathStr: String, conf: Configuration): Int = { val files = FileSystemHelper.listFiles(pathStr, conf) // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-r-(\d{1,}).parquet""", "taskid") + val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid") val hiddenFileP = new scala.util.matching.Regex("_.*") files.map(_.getName).map { case nameP(taskid) => taskid.toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala new file mode 100644 index 0000000000000..d810d6a028c58 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import java.util.{List => JList} + +import scala.collection.JavaConversions._ +import scala.util.Try + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import parquet.filter2.predicate.FilterApi +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop._ +import parquet.hadoop.metadata.CompressionCodecName +import parquet.hadoop.util.ContextUtil + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.RDD._ +import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} + +private[sql] class DefaultSource extends FSBasedRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): FSBasedRelation = { + val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty)) + new FSBasedParquetRelation(paths, schema, partitionSpec, parameters)(sqlContext) + } +} + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[sql] class ParquetOutputWriter extends OutputWriter { + private var recordWriter: RecordWriter[Void, Row] = _ + private var taskAttemptContext: TaskAttemptContext = _ + + override def init( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): Unit = { + val conf = context.getConfiguration + val outputFormat = { + // When appending new Parquet files to an existing Parquet file directory, to avoid + // overwriting existing data files, we need to find out the max task ID encoded in these data + // file names. + // TODO Make this snippet a utility function for other data source developers + val maxExistingTaskId = { + // Note that `path` may point to a temporary location. Here we retrieve the real + // destination path from the configuration + val outputPath = new Path(conf.get("spark.sql.sources.output.path")) + val fs = outputPath.getFileSystem(conf) + + if (fs.exists(outputPath)) { + // Pattern used to match task ID in part file names, e.g.: + // + // part-r-00001.gz.part + // ^~~~~ + val partFilePattern = """part-.-(\d{1,}).*""".r + + fs.listStatus(outputPath).map(_.getPath.getName).map { + case partFilePattern(id) => id.toInt + case name if name.startsWith("_") => 0 + case name if name.startsWith(".") => 0 + case name => sys.error( + s"""Trying to write Parquet files to directory $outputPath, + |but found items with illegal name "$name" + """.stripMargin.replace('\n', ' ').trim) + }.reduceOption(_ max _).getOrElse(0) + } else { + 0 + } + } + + new ParquetOutputFormat[Row]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate output file name based on the max available + // task ID computed above. + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. 
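+ //
+ // Illustrative example (hypothetical file listing): if the destination directory
+ // already holds part-r-00000.parquet and part-r-00001.parquet, then maxExistingTaskId
+ // is 1, so the first task of an appending write (task ID 0) writes
+ // part-r-00002<extension> and never clashes with the existing part-files.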
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 + new Path(path, f"part-r-$split%05d$extension") + } + } + } + + recordWriter = outputFormat.getRecordWriter(context) + taskAttemptContext = context + } + + override def write(row: Row): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(taskAttemptContext) +} + +private[sql] class FSBasedParquetRelation( + paths: Array[String], + private val maybeDataSchema: Option[StructType], + private val maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + val sqlContext: SQLContext) + extends FSBasedRelation(paths, maybePartitionSpec) + with Logging { + + // Should we merge schemas from all Parquet part-files? + private val shouldMergeSchemas = + parameters.getOrElse(FSBasedParquetRelation.MERGE_SCHEMA, "true").toBoolean + + private val maybeMetastoreSchema = parameters + .get(FSBasedParquetRelation.METASTORE_SCHEMA) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + private val metadataCache = new MetadataCache + metadataCache.refresh() + + override def equals(other: scala.Any): Boolean = other match { + case that: FSBasedParquetRelation => + val schemaEquality = if (shouldMergeSchemas) { + this.shouldMergeSchemas == that.shouldMergeSchemas + } else { + this.dataSchema == that.dataSchema && + this.schema == that.schema + } + + this.paths.toSet == that.paths.toSet && + schemaEquality && + this.maybeDataSchema == that.maybeDataSchema && + this.partitionColumns == that.partitionColumns + + case _ => false + } + + override def hashCode(): Int = { + if (shouldMergeSchemas) { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + maybeDataSchema, + maybePartitionSpec) + } else { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + dataSchema, + schema, + maybeDataSchema, + maybePartitionSpec) + } + } + + override def outputWriterClass: Class[_ <: OutputWriter] = classOf[ParquetOutputWriter] + + override def dataSchema: StructType = metadataCache.dataSchema + + override private[sql] def refresh(): Unit = { + metadataCache.refresh() + super.refresh() + } + + // Parquet data source always uses Catalyst internal representations. + override val needConversion: Boolean = false + + override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum + + override def prepareForWrite(job: Job): Unit = { + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + "spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], + classOf[ParquetOutputCommitter]) + + conf.setClass( + "mapred.output.committer.class", + committerClass, + classOf[ParquetOutputCommitter]) + + // TODO There's no need to use two kinds of WriteSupport + // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and + // complex types. 
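+ //
+ // In other words: a schema whose fields are all atomic (primitive) types takes the
+ // MutableRowWriteSupport branch below, while any complex field (e.g. a struct, array,
+ // or map) forces the generic RowWriteSupport.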
+ val writeSupportClass = + if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) + RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + + // Sets compression scheme + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String]): RDD[Row] = { + + val job = Job.getInstance(SparkHadoopUtil.get.conf) + val conf = ContextUtil.getConfiguration(job) + + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + + if (inputPaths.nonEmpty) { + FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) + } + + // Try to push down filters when filter push-down is enabled. + if (sqlContext.conf.parquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(dataSchema, _)) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) + } + + conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) + ParquetTypesConverter.convertToString(requestedSchema.toAttributes) + }) + + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(dataSchema.toAttributes)) + + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean + conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString) + + val inputFileStatuses = + metadataCache.dataStatuses.filter(f => inputPaths.contains(f.getPath.toString)) + + val footers = inputFileStatuses.map(metadataCache.footers) + + // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. + // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects and + // footers. Especially when a global arbitrative schema (either from metastore or data source + // DDL) is available. + new NewHadoopRDD( + sqlContext.sparkContext, + classOf[FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row], + conf) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFileStatuses.map { f => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + val pathWithAuthority = new Path(f.getPath.toUri.toString) + + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithAuthority) + }.toSeq + + @transient val cachedFooters = footers.map { f => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) + }.toSeq + + // Overridden so we can inject our own cached files statuses. 
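+ //
+ // Note: when SQLConf.PARQUET_CACHE_METADATA is enabled (the default), the overridden
+ // methods below serve the FileStatuses and footers gathered by MetadataCache.refresh()
+ // instead of re-listing the file system and re-reading Parquet footers for every query.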
+ override def getPartitions: Array[SparkPartition] = { + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses + + override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + } + } else { + new FilteringParquetRowInputFormat + } + + val jobContext = newJobContext(getConf, jobId) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + } + }.values + } + + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ + + // Parquet footer cache. + var footers: Map[FileStatus, Footer] = _ + + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var dataSchema: StructType = _ + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + val baseStatuses = paths.distinct.flatMap { p => + val path = new Path(p) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(fs.getFileStatus(qualified)).toOption + } + assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = baseStatuses.flatMap { f => + val fs = FileSystem.get(f.getPath.toUri, SparkHadoopUtil.get.conf) + SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + } + } + + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => + val parquetMetadata = ParquetFileReader.readFooter( + SparkHadoopUtil.get.conf, f, ParquetMetadataConverter.NO_FILTER) + f -> new Footer(f.getPath, parquetMetadata) + }.seq.toMap + + dataSchema = { + val dataSchema0 = + maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // case insensitivity issue and possible schema mismatch (probably caused by schema + // evolution). 
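+ //
+ // Illustrative example (hypothetical column name): if the metastore describes a column
+ // as `col1: string` while the Parquet footers report `COL1: binary`, the merged schema
+ // keeps the Parquet-side name "COL1" together with the metastore type StringType.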
+ maybeMetastoreSchema + .map(FSBasedParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) + } + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + private def readSchema(): Option[StructType] = { + // Sees which file(s) we need to touch in order to figure out the schema. + // + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + val filesToTouch = + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } + + assert( + filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, + "No schema defined, " + + s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") + + FSBasedParquetRelation.readSchema(filesToTouch.map(footers.apply), sqlContext) + } + } +} + +private[sql] object FSBasedParquetRelation extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + private[sql] val MERGE_SCHEMA = "mergeSchema" + + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" + + private[parquet] def readSchema( + footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + footers.map { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val parquetSchema = metadata.getSchema + val maybeSparkSchema = metadata + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .flatMap { serializedSchema => + // Don't throw even if we failed to parse the serialized Spark schema. 
Just fallback to + // whatever is available. + Try(DataType.fromJson(serializedSchema)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .toOption + } + + maybeSparkSchema.getOrElse { + // Falls back to Parquet schema if Spark SQL schema is absent. + StructType.fromAttributes( + // TODO Really no need to use `Attribute` here, we only need to know the data type. + ParquetTypesConverter.convertToAttributes( + parquetSchema, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) + } + }.reduceOption { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage: String = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. 
This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala deleted file mode 100644 index ee4b1c72a2148..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ /dev/null @@ -1,840 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} -import java.math.{BigDecimal => JBigDecimal} -import java.net.URI -import java.text.SimpleDateFormat -import java.util.{Date, List => JList} - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat -import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} -import parquet.filter2.predicate.FilterApi -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetInputFormat, _} - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} -import org.apache.spark.sql.parquet.ParquetTypesConverter._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} - -/** - * Allows creation of Parquet based tables using the syntax: - * {{{ - * CREATE TEMPORARY TABLE 
... USING org.apache.spark.sql.parquet OPTIONS (...) - * }}} - * - * Supported options include: - * - * - `path`: Required. When reading Parquet files, `path` should point to the location of the - * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. - * In the latter case, this data source tries to discover partitioning information if the the - * directory is structured in the same style of Hive partitioned tables. When writing Parquet - * file, `path` should point to the destination folder. - * - * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but - * compatible) schemas stored in all Parquet part-files. - * - * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is - * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration - * in Hive. - */ -private[sql] class DefaultSource - extends RelationProvider - with SchemaRelationProvider - with CreatableRelationProvider { - - private def checkPath(parameters: Map[String, String]): String = { - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) - } - - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) - } - - /** Returns a new base relation with the given parameters and schema. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { - ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) - } - - /** Returns a new base relation with the given parameters and save given data into it. */ - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val path = checkPath(parameters) - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doInsertion = (mode, fs.exists(filesystemPath)) match { - case (SaveMode.ErrorIfExists, true) => - sys.error(s"path $path already exists.") - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - } - - val relation = if (doInsertion) { - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val df = - sqlContext.createDataFrame( - data.queryExecution.toRdd, - data.schema.asNullable, - needsConversion = false) - val createdRelation = - createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] - createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) - createdRelation - } else { - // If the save mode is Ignore, we will just create the relation based on existing data. - createRelation(sqlContext, parameters) - } - - relation - } -} - -/** - * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will - * be deprecated and eventually removed once this version is proved to be stable enough. 
- * - * Compared with the old implementation, this class has the following notable differences: - * - * - Partitioning discovery: Hive style multi-level partitions are auto discovered. - * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source - * can detect and merge schemas from all Parquet part-files as long as they are compatible. - * Also, metadata and [[FileStatus]]es are cached for better performance. - * - Statistics: Statistics for the size of the table are automatically populated during schema - * discovery. - */ -@DeveloperApi -private[sql] case class ParquetRelation2( - paths: Seq[String], - parameters: Map[String, String], - maybeSchema: Option[StructType] = None, - maybePartitionSpec: Option[PartitionSpec] = None)( - @transient val sqlContext: SQLContext) - extends BaseRelation - with CatalystScan - with InsertableRelation - with SparkHadoopMapReduceUtil - with Logging { - - // Should we merge schemas from all Parquet part-files? - private val shouldMergeSchemas = - parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean - - // Optional Metastore schema, used when converting Hive Metastore Parquet table - private val maybeMetastoreSchema = - parameters - .get(ParquetRelation2.METASTORE_SCHEMA) - .map(s => DataType.fromJson(s).asInstanceOf[StructType]) - - // Hive uses this as part of the default partition name when the partition column value is null - // or empty string - private val defaultPartitionName = parameters.getOrElse( - ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") - - override def equals(other: Any): Boolean = other match { - case relation: ParquetRelation2 => - // If schema merging is required, we don't compare the actual schemas since they may evolve. - val schemaEquality = if (shouldMergeSchemas) { - shouldMergeSchemas == relation.shouldMergeSchemas - } else { - schema == relation.schema - } - - paths.toSet == relation.paths.toSet && - schemaEquality && - maybeMetastoreSchema == relation.maybeMetastoreSchema && - maybePartitionSpec == relation.maybePartitionSpec - - case _ => false - } - - override def hashCode(): Int = { - if (shouldMergeSchemas) { - com.google.common.base.Objects.hashCode( - shouldMergeSchemas: java.lang.Boolean, - paths.toSet, - maybeMetastoreSchema, - maybePartitionSpec) - } else { - com.google.common.base.Objects.hashCode( - shouldMergeSchemas: java.lang.Boolean, - schema, - paths.toSet, - maybeMetastoreSchema, - maybePartitionSpec) - } - } - - private[sql] def sparkContext = sqlContext.sparkContext - - private class MetadataCache { - // `FileStatus` objects of all "_metadata" files. - private var metadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all "_common_metadata" files. - private var commonMetadataStatuses: Array[FileStatus] = _ - - // Parquet footer cache. - var footers: Map[FileStatus, Footer] = _ - - // `FileStatus` objects of all data files (Parquet part-files). - var dataStatuses: Array[FileStatus] = _ - - // Partition spec of this table, including names, data types, and values of each partition - // column, and paths of each partition. - var partitionSpec: PartitionSpec = _ - - // Schema of the actual Parquet files, without partition columns discovered from partition - // directory paths. - var parquetSchema: StructType = _ - - // Schema of the whole table, including partition columns. - var schema: StructType = _ - - // Indicates whether partition columns are also included in Parquet data file schema. 
If not, - // we need to fill in partition column values into read rows when scanning the table. - var partitionKeysIncludedInParquetSchema: Boolean = _ - - def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) - - ParquetRelation.enableLogForwarding() - ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) - } - - /** - * Refreshes `FileStatus`es, footers, partition spec, and table schema. - */ - def refresh(): Unit = { - // Support either reading a collection of raw Parquet part-files, or a collection of folders - // containing Parquet files (e.g. partitioned Parquet table). - val baseStatuses = paths.distinct.map { p => - val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val path = new Path(p) - val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) - - if (!fs.exists(qualified) && maybeSchema.isDefined) { - fs.mkdirs(qualified) - prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) - } - - fs.getFileStatus(qualified) - }.toArray - assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) - - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = baseStatuses.flatMap { f => - val fs = FileSystem.get(f.getPath.toUri, sparkContext.hadoopConfiguration) - SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - } - } - - dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) - metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) - commonMetadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - - footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => - val parquetMetadata = ParquetFileReader.readFooter( - sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) - f -> new Footer(f.getPath, parquetMetadata) - }.seq.toMap - - partitionSpec = maybePartitionSpec.getOrElse { - val partitionDirs = leaves - .filterNot(baseStatuses.contains) - .map(_.getPath.getParent) - .distinct - - if (partitionDirs.nonEmpty) { - // Parses names and values of partition columns, and infer their data types. - PartitioningUtils.parsePartitions(partitionDirs, defaultPartitionName) - } else { - // No partition directories found, makes an empty specification - PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) - } - } - - // To get the schema. We first try to get the schema defined in maybeSchema. - // If maybeSchema is not defined, we will try to get the schema from existing parquet data - // (through readSchema). If data does not exist, we will try to get the schema defined in - // maybeMetastoreSchema (defined in the options of the data source). - // Finally, if we still could not get the schema. We throw an error. 
- parquetSchema = - maybeSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) - - partitionKeysIncludedInParquetSchema = - isPartitioned && - partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) - - schema = { - val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { - parquetSchema - } else { - StructType(parquetSchema.fields ++ partitionColumns.fields) - } - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // insensitivity issue and possible schema mismatch. - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) - .getOrElse(fullRelationSchema) - } - } - - private def readSchema(): Option[StructType] = { - // Sees which file(s) we need to touch in order to figure out the schema. - val filesToTouch = - // Always tries the summary files first if users don't require a merged schema. In this case, - // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row - // groups information, and could be much smaller for large Parquet files with lots of row - // groups. - // - // NOTE: Metadata stored in the summary files are merged from all part-files. However, for - // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know - // how to merge them correctly if some key is associated with different values in different - // part-files. When this happens, Parquet simply gives up generating the summary file. This - // implies that if a summary file presents, then: - // - // 1. Either all part-files have exactly the same Spark SQL schema, or - // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus - // their schemas may differ from each other). - // - // Here we tend to be pessimistic and take the second case into account. Basically this means - // we can't trust the summary files if users require a merged schema, and must touch all part- - // files to do the merge. - if (shouldMergeSchemas) { - // Also includes summary files, 'cause there might be empty partition directories. - (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq - } else { - // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet - // don't have this. - commonMetadataStatuses.headOption - // Falls back to "_metadata" - .orElse(metadataStatuses.headOption) - // Summary file(s) not found, the Parquet file is either corrupted, or different part- - // files contain conflicting user defined metadata (two or more values are associated - // with a same key in different files). In either case, we fall back to any of the - // first part-file, and just assume all schemas are consistent. 
- .orElse(dataStatuses.headOption) - .toSeq - } - - ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) - } - } - - @transient private val metadataCache = new MetadataCache - metadataCache.refresh() - - def partitionSpec: PartitionSpec = metadataCache.partitionSpec - - def partitionColumns: StructType = metadataCache.partitionSpec.partitionColumns - - def partitions: Seq[Partition] = metadataCache.partitionSpec.partitions - - def isPartitioned: Boolean = partitionColumns.nonEmpty - - private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema - - private def parquetSchema = metadataCache.parquetSchema - - override def schema: StructType = metadataCache.schema - - private def isSummaryFile(file: Path): Boolean = { - file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || - file.getName == ParquetFileWriter.PARQUET_METADATA_FILE - } - - // Skip type conversion - override val needConversion: Boolean = false - - // TODO Should calculate per scan size - // It's common that a query only scans a fraction of a large Parquet file. Returning size of the - // whole Parquet file disables some optimizations in this case (e.g. broadcast join). - override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum - - // This is mostly a hack so that we can use the existing parquet filter code. - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) - - val selectedPartitions = prunePartitions(predicates, partitions) - val selectedFiles = if (isPartitioned) { - selectedPartitions.flatMap { p => - metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) - } - } else { - metadataCache.dataStatuses.toSeq - } - val selectedFooters = selectedFiles.map(metadataCache.footers) - - // FileInputFormat cannot handle empty lists. - if (selectedFiles.nonEmpty) { - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) - FileInputFormat.setInputPaths(job, selectedPaths: _*) - } - - // Try to push down filters when filter push-down is enabled. - if (sqlContext.conf.parquetFilterPushDown) { - val partitionColNames = partitionColumns.map(_.name).toSet - predicates - // Don't push down predicates which reference partition columns - .filter { pred => - val referencedColNames = pred.references.map(_.name).toSet - referencedColNames.intersect(partitionColNames).isEmpty - } - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. 
- .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) - } - - if (isPartitioned) { - logInfo { - val percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - s"Reading $percentRead% of partitions" - } - } - - val requiredColumns = output.map(_.name) - val requestedSchema = StructType(requiredColumns.map(schema(_))) - - // Store both requested and original schema in `Configuration` - jobConf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - convertToString(requestedSchema.toAttributes)) - jobConf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - convertToString(schema.toAttributes)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean - jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) - - val baseRDD = - new NewHadoopRDD( - sparkContext, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row], - jobConf) { - val cacheMetadata = useCache - - @transient - val cachedStatus = selectedFiles.map { st => - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - val newPath = new Path(st.getPath.toUri.toString) - - new FileStatus( - st.getLen, - st.isDir, - st.getReplication, - st.getBlockSize, - st.getModificationTime, - st.getAccessTime, - st.getPermission, - st.getOwner, - st.getGroup, - newPath) - } - - @transient - val cachedFooters = selectedFooters.map { f => - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) - } - - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - - override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters - } - } else { - new FilteringParquetRowInputFormat - } - - val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - } - } - - // The ordinals for partition keys in the result row, if requested. - val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { - case (name, index) => index -> requiredColumns.indexOf(name) - }.toMap.filter { - case (_, index) => index >= 0 - } - - // When the data does not include the key and the key is requested then we must fill it in - // based on information from the input split. - if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { - // This check is based on CatalystConverter.createRootConverter. 
- val primitiveRow = - requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) - - baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => - val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => - CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] - }.get - - val requiredPartOrdinal = partitionKeyLocations.keys.toSeq - - if (primitiveRow) { - iterator.map { pair => - // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. - val row = pair._2.asInstanceOf[SpecificMutableRow] - var i = 0 - while (i < requiredPartOrdinal.size) { - // TODO Avoids boxing cost here! - val partOrdinal = requiredPartOrdinal(i) - row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) - i += 1 - } - row - } - } else { - // Create a mutable row since we need to fill in values from partition columns. - val mutableRow = new GenericMutableRow(requestedSchema.size) - iterator.map { pair => - // We are using CatalystGroupConverter and it returns a GenericRow. - // Since GenericRow is not mutable, we just cast it to a Row. - val row = pair._2.asInstanceOf[Row] - var i = 0 - while (i < row.size) { - // TODO Avoids boxing cost here! - mutableRow(i) = row(i) - i += 1 - } - - i = 0 - while (i < requiredPartOrdinal.size) { - // TODO Avoids boxing cost here! - val partOrdinal = requiredPartOrdinal(i) - mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) - i += 1 - } - mutableRow - } - } - } - } else { - baseRDD.map(_._2) - } - } - - private def prunePartitions( - predicates: Seq[Expression], - partitions: Seq[Partition]): Seq[Partition] = { - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - val rawPredicate = - partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) - val boundPredicate = InterpretedPredicate.create(rawPredicate transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - if (isPartitioned && partitionPruningPredicates.nonEmpty) { - partitions.filter(p => boundPredicate(p.values)) - } else { - partitions - } - } - - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") - - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). 
- - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = - if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(data.schema.toAttributes, conf) - - val destinationPath = new Path(paths.head) - - if (overwrite) { - val fs = destinationPath.getFileSystem(conf) - if (fs.exists(destinationPath)) { - var success: Boolean = false - try { - success = fs.delete(destinationPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${destinationPath.toString} prior" + - s" to writing to Parquet table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${destinationPath.toString} prior" + - s" to writing to Parquet table.") - } - } - } - - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[Row]) - FileOutputFormat.setOutputPath(job, destinationPath) - - val wrappedConf = new SerializableWritable(job.getConfiguration) - val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = if (overwrite) { - 1 - } else { - FileSystemHelper.findMaxTaskId( - FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { - /* "reduce task" */ - val attemptId = newTaskAttemptID( - jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iterator.hasNext) { - val row = iterator.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - - SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) - jobCommitter.commitJob(jobTaskContext) - - metadataCache.refresh() - } -} - -private[sql] object ParquetRelation2 extends Logging { - // Whether we should merge schemas collected from all Parquet part-files. - val MERGE_SCHEMA = "mergeSchema" - - // Default partition name to use when the partition column value is null or empty string. - val DEFAULT_PARTITION_NAME = "partition.defaultName" - - // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used - // internally. 
- private[sql] val METASTORE_SCHEMA = "metastoreSchema" - - private[parquet] def readSchema( - footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { - footers.map { footer => - val metadata = footer.getParquetMetadata.getFileMetaData - val parquetSchema = metadata.getSchema - val maybeSparkSchema = metadata - .getKeyValueMetaData - .toMap - .get(RowReadSupport.SPARK_METADATA_KEY) - .flatMap { serializedSchema => - // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to - // whatever is available. - Try(DataType.fromJson(serializedSchema)) - .recover { case _: Throwable => - logInfo( - s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + - "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema) - } - .recover { case cause: Throwable => - logWarning( - s"""Failed to parse serialized Spark schema in Parquet key-value metadata: - |\t$serializedSchema - """.stripMargin, - cause) - } - .map(_.asInstanceOf[StructType]) - .toOption - } - - maybeSparkSchema.getOrElse { - // Falls back to Parquet schema if Spark SQL schema is absent. - StructType.fromAttributes( - // TODO Really no need to use `Attribute` here, we only need to know the data type. - convertToAttributes( - parquetSchema, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp)) - } - }.reduceOption { (left, right) => - try left.merge(right) catch { case e: Throwable => - throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) - } - } - } - - /** - * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore - * schema and Parquet schema. - * - * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the - * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't - * distinguish binary and string). This method generates a correct schema by merging Metastore - * schema data types and Parquet schema field names. - */ - private[parquet] def mergeMetastoreParquetSchema( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - def schemaConflictMessage: String = - s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: - |${metastoreSchema.prettyJson} - | - |Parquet schema: - |${parquetSchema.prettyJson} - """.stripMargin - - val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) - - assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) - - val ordinalMap = metastoreSchema.zipWithIndex.map { - case (field, index) => field.name.toLowerCase -> index - }.toMap - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => - ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) - - StructType(metastoreSchema.zip(reorderedParquetSchema).map { - // Uses Parquet field names but retains Metastore data types. - case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => - mSchema.copy(name = pSchema.name) - case _ => - throw new SparkException(schemaConflictMessage) - }) - } - - /** - * Returns the original schema from the Parquet file with any missing nullable fields from the - * Hive Metastore schema merged in. 
- * - * When constructing a DataFrame from a collection of structured data, the resulting object has - * a schema corresponding to the union of the fields present in each element of the collection. - * Spark SQL simply assigns a null value to any field that isn't present for a particular row. - * In some cases, it is possible that a given table partition stored as a Parquet file doesn't - * contain a particular nullable field in its schema despite that field being present in the - * table schema obtained from the Hive Metastore. This method returns a schema representing the - * Parquet file schema along with any additional nullable fields from the Metastore schema - * merged in. - */ - private[parquet] def mergeMissingNullableFields( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap - val missingFields = metastoreSchema - .map(_.name.toLowerCase) - .diff(parquetSchema.map(_.name.toLowerCase)) - .map(fieldMap(_)) - .filter(_.nullable) - StructType(parquetSchema ++ missingFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index a294297677d1a..7879328bbaaab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -293,9 +293,18 @@ private[sql] abstract class BaseWriterContainer( } private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - outputFormatClass.newInstance().getOutputCommitter(context) + val committerClass = context.getConfiguration.getClass( + "mapred.output.committer.class", null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + }.getOrElse { + outputFormatClass.newInstance().getOutputCommitter(context) + } } + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) @@ -345,6 +354,7 @@ private[sql] class DefaultWriterContainer( override protected def initWriters(): Unit = { writer = outputWriterClass.newInstance() + taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) writer.init(getWorkPath, dataSchema, taskAttemptContext) } @@ -384,11 +394,14 @@ private[sql] class DynamicPartitionWriterContainer( DynamicPartitionWriterContainer.escapePathName(string) } s"/$col=$valueString" - }.mkString + }.mkString.stripPrefix(Path.SEPARATOR) outputWriters.getOrElseUpdate(partitionPath, { - val path = new Path(getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR)) + val path = new Path(getWorkPath, partitionPath) val writer = outputWriterClass.newInstance() + taskAttemptContext.getConfiguration.set( + "spark.sql.sources.output.path", + new Path(outputPath, partitionPath).toString) writer.init(path.toString, dataSchema, taskAttemptContext) writer }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 10d0ede4dc0dc..3bbc5b05868af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -63,7 +63,7 @@ class 
ParquetFilterSuiteBase extends QueryTest with ParquetTest { }.flatten.reduceOption(_ && _) val forParquetDataSource = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters + case PhysicalOperation(_, filters, LogicalRelation(_: FSBasedParquetRelation)) => filters }.flatten.reduceOption(_ && _) forParquetTableScan.orElse(forParquetDataSource) @@ -350,7 +350,7 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } - + test("SPARK-6742: don't push down predicates which reference partition columns") { import sqlContext.implicits._ @@ -365,7 +365,7 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, Seq(AttributeReference("part", IntegerType, false)()) )) - + checkAnswer( df.filter("a = 1 or part = 1"), (1 to 3).map(i => Row(1, i, i.toString))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index b504842053690..7c371dbc7d3c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -119,7 +119,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { + intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) parquetFile(dir.getCanonicalPath).collect() @@ -127,7 +127,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { + intercept[Throwable] { withTempPath { dir => makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) parquetFile(dir.getCanonicalPath).collect() @@ -419,7 +419,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA test("SPARK-6330 regression test") { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// - intercept[java.io.FileNotFoundException] { + intercept[Throwable] { sqlContext.parquetFile("file:///nonexistent") } val errorMessage = intercept[Throwable] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index bea568ed40049..138e19766dc88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -39,7 +39,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { import sqlContext._ import sqlContext.implicits._ - val defaultPartitionName = "__NULL__" + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" test("column type inference") { def check(raw: String, literal: Literal): Unit = { @@ -252,9 +252,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { val parquetRelation = load( "org.apache.spark.sql.parquet", - Map( - "path" -> base.getCanonicalPath, - 
ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + Map("path" -> base.getCanonicalPath)) parquetRelation.registerTempTable("t") @@ -297,9 +295,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { val parquetRelation = load( "org.apache.spark.sql.parquet", - Map( - "path" -> base.getCanonicalPath, - ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + Map("path" -> base.getCanonicalPath)) parquetRelation.registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d984557..fc90e3edce7fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -204,7 +204,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -219,7 +219,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -230,7 +230,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -241,7 +241,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Conflicting non-nullable field names intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -255,7 +255,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -268,7 +268,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. 
assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + FSBasedParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d754c8e3a8aa1..b0e82c8d033b2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,10 +33,10 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.FSBasedParquetRelation import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} import org.apache.spark.util.Utils /* Implicit conversions */ @@ -226,8 +226,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + FSBasedParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, + FSBasedParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -238,13 +238,15 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + case logical@LogicalRelation(parquetRelation: FSBasedParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && - parquetRelation.maybePartitionSpec == partitionSpecInMetastore + parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { + PartitionSpec(StructType(Nil), Array.empty[sources.Partition]) + } if (useCached) { Some(logical) @@ -256,7 +258,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case other => logWarning( s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"as Parquet. However, we are getting a $other from the metastore cache. 
" + s"This cached entry will be invalidated.") cachedDataSourceTables.invalidate(tableIdentifier) None @@ -278,8 +280,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { - val created = - LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + val created = LogicalRelation( + new FSBasedParquetRelation( + paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -290,8 +293,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { - val created = - LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + val created = LogicalRelation( + new FSBasedParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 47c60f651d14c..da5d203d9d343 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,21 +21,18 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterEach - import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.metastore.TableType -import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException +import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql._ -import org.apache.spark.util.Utils -import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.FSBasedParquetRelation import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. @@ -582,11 +579,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK + case LogicalRelation(p: FSBasedParquetRelation) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + s"${classOf[FSBasedParquetRelation].getCanonicalName}") } // Clenup and reset confs. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a5744ccc68a47..1d6393a3fec85 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.DefaultParserDialect -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.MetastoreRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim} -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} +import org.apache.spark.sql.parquet.FSBasedParquetRelation import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, DefaultParserDialect, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -176,17 +174,17 @@ class SQLQuerySuite extends QueryTest { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { - case LogicalRelation(r: ParquetRelation2) => + case LogicalRelation(r: FSBasedParquetRelation) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${ParquetRelation2.getClass.getCanonicalName}.") + s"${FSBasedParquetRelation.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${FSBasedParquetRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } @@ -596,7 +594,7 @@ class SQLQuerySuite extends QueryTest { sql(s"DROP TABLE $tableName") } } - + test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index bf1121ddf0273..41bcbe84b0ef2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,16 +21,15 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{QueryTest, SQLConf} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} -import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.parquet.{FSBasedParquetRelation, ParquetTableScan} +import org.apache.spark.sql.sources.{InsertIntoDataSource, 
InsertIntoFSBasedRelation, LogicalRelation} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -292,10 +291,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}") + case LogicalRelation(_: FSBasedParquetRelation) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[FSBasedParquetRelation].getCanonicalName}") } sql("DROP TABLE IF EXISTS test_parquet_ctas") @@ -316,12 +315,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand( - InsertIntoDataSource( - LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case ExecutedCommand(InsertIntoFSBasedRelation(_: FSBasedParquetRelation, _, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"${classOf[FSBasedParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -348,11 +345,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand( - InsertIntoDataSource( - LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case ExecutedCommand(InsertIntoFSBasedRelation(r: FSBasedParquetRelation, _, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[FSBasedParquetRelation].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + s"However, found a ${o.toString} ") } @@ -383,7 +378,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation2) => r + case r @ LogicalRelation(_: FSBasedParquetRelation) => r }.size } @@ -395,7 +390,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case logical @ LogicalRelation(parquetRelation: FSBasedParquetRelation) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. 
" + @@ -693,7 +688,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + intercept[Throwable](df2.saveAsParquetFile(filePath)) val df3 = df2.toDF("str", "max_int") df3.saveAsParquetFile(filePath2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala similarity index 83% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala index e8b48a0db1c79..394833f22907d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala @@ -28,12 +28,14 @@ import org.apache.spark.sql.types._ // TODO Don't extend ParquetTest // This test suite extends ParquetTest for some convenient utility methods. These methods should be // moved to some more general places, maybe QueryTest. -class FSBasedRelationSuite extends QueryTest with ParquetTest { +class FSBasedRelationTest extends QueryTest with ParquetTest { override val sqlContext: SQLContext = TestHive import sqlContext._ import sqlContext.implicits._ + val dataSourceName = classOf[SimpleTextSource].getCanonicalName + val dataSchema = StructType( Seq( @@ -92,17 +94,17 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempPath { file => testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite) testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite) checkAnswer( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json)), @@ -114,17 +116,17 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempPath { file => testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite) testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append) checkAnswer( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json)).orderBy("a"), @@ -137,7 +139,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { intercept[RuntimeException] { testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.ErrorIfExists) } } @@ -147,7 +149,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempDir { file => testDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Ignore) val path = new Path(file.getCanonicalPath) @@ -159,62 +161,37 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("save()/load() - partitioned table - simple queries") { withTempPath { file => 
partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.ErrorIfExists, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) checkQueries( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json))) } } - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) - } - } - test("save()/load() - partitioned table - Overwrite") { withTempPath { file => partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) checkAnswer( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json)), @@ -225,20 +202,20 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("save()/load() - partitioned table - Append") { withTempPath { file => partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) checkAnswer( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json)), @@ -249,20 +226,20 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("save()/load() - partitioned table - Append - new partition values") { withTempPath { file => partitionedTestDF1.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) partitionedTestDF2.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) checkAnswer( load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchema.json)), @@ -274,7 +251,7 @@ class FSBasedRelationSuite 
extends QueryTest with ParquetTest { withTempDir { file => intercept[RuntimeException] { partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.ErrorIfExists, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) @@ -286,7 +263,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempDir { file => partitionedTestDF.save( path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Ignore) val path = new Path(file.getCanonicalPath) @@ -302,7 +279,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - non-partitioned table - Overwrite") { testDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, Map("dataSchema" -> dataSchema.json)) @@ -314,12 +291,12 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - non-partitioned table - Append") { testDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite) testDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append) withTable("t") { @@ -334,7 +311,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { intercept[AnalysisException] { testDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.ErrorIfExists) } } @@ -346,7 +323,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempTable("t") { testDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Ignore) assert(table("t").collect().isEmpty) @@ -356,7 +333,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - partitioned table - simple queries") { partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, Map("dataSchema" -> dataSchema.json)) @@ -368,14 +345,14 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) @@ -388,14 +365,14 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - partitioned table - Append") { partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("dataSchema" -> dataSchema.json), partitionColumns = 
Seq("p1", "p2")) @@ -408,14 +385,14 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - partitioned table - Append - new partition values") { partitionedTestDF1.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) partitionedTestDF2.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) @@ -428,7 +405,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { partitionedTestDF1.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) @@ -437,7 +414,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { intercept[Throwable] { partitionedTestDF2.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1")) @@ -447,7 +424,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { intercept[Throwable] { partitionedTestDF2.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Append, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p2", "p1")) @@ -461,7 +438,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { intercept[AnalysisException] { partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.ErrorIfExists, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) @@ -475,7 +452,7 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { withTempTable("t") { partitionedTestDF.saveAsTable( tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Ignore, options = Map("dataSchema" -> dataSchema.json), partitionColumns = Seq("p1", "p2")) @@ -487,13 +464,13 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { test("Hadoop style globbing") { withTempPath { file => partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, mode = SaveMode.Overwrite, options = Map("path" -> file.getCanonicalPath), partitionColumns = Seq("p1", "p2")) val df = load( - source = classOf[SimpleTextSource].getCanonicalName, + source = dataSourceName, options = Map( "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", "dataSchema" -> dataSchema.json)) @@ -521,3 +498,67 @@ class FSBasedRelationSuite extends QueryTest with ParquetTest { } } } + +class SimpleTextRelationSuite extends FSBasedRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = 
fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} + +class FSBasedParquetRelationSuite extends FSBasedRelationTest { + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .saveAsParquetFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} From 213a6f30fee4a1c416ea76b678c71877fd36ef18 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 May 2015 12:47:48 -0700 Subject: [PATCH 011/109] [SPARK-7551][DataFrame] support backticks for DataFrame attribute resolution Author: Wenchen Fan Closes #6074 from cloud-fan/7551 and squashes the following commits: e6f579e [Wenchen Fan] allow space 2b86699 [Wenchen Fan] handle blank e218d99 [Wenchen Fan] address comments 54c4209 [Wenchen Fan] fix 7551 --- .../catalyst/plans/logical/LogicalPlan.scala | 55 ++++++++++++++++++- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 27 +++++++++ 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dbb12d56f9497..dba69659afc80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -105,7 +105,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } /** - * Optionally resolves the given string to a [[NamedExpression]] using the input from all child + * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ @@ -116,7 +116,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** - * Optionally resolves the given string to a [[NamedExpression]] based on the output of this + * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this * LogicalPlan. 
The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ @@ -126,6 +126,57 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { throwErrors: Boolean = false): Option[NamedExpression] = resolve(nameParts, output, resolver, throwErrors) + /** + * Given an attribute name, split it to name parts by dot, but + * don't split the name parts quoted by backticks, for example, + * `ab.cd`.`efg` should be split into two parts "ab.cd" and "efg". + */ + def resolveQuoted( + name: String, + resolver: Resolver): Option[NamedExpression] = { + resolve(parseAttributeName(name), resolver, true) + } + + /** + * Internal method, used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * Escape character is not supported now, so we can't use backtick inside name part. + */ + private def parseAttributeName(name: String): Seq[String] = { + val e = new AnalysisException(s"syntax error in attribute name: $name") + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (tmp.isEmpty) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (tmp.isEmpty || inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } + /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c820a673575ff..4fd5105c27443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,7 +160,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -168,7 +168,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52aa1f6558f80..1d5f6b3aad6fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,6 +459,33 @@ class DataFrameSuite extends QueryTest { assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) } + test("SPARK-7551: support backticks for DataFrame attribute resolution") { + val df = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) + checkAnswer( + df.select(df("`a.b`.c.`d..e`.`f`")), + Row(1) + ) + + val df2 = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) + checkAnswer( + df2.select(df2("`a b`.c.d e.f")), + Row(1) + ) + + def checkError(testFun: => Unit): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + testFun + } + assert(e.getMessage.contains("syntax error in attribute name:")) + } + checkError(df("`abc.`c`")) + checkError(df("`abc`..d")) + checkError(df("`a`.b.")) + checkError(df("`a.b`.c.`d")) + } + test("SPARK-7324 dropDuplicates") { val testData = TestSQLContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: From e676fc0c6326f3ddeced87214cc88534ea646473 Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Wed, 13 May 2015 21:00:12 +0100 Subject: [PATCH 012/109] [MINOR] Avoid passing the PermGenSize option to IBM JVMs. IBM's Java VM doesn't have the concept of a permgen, so this option shouldn't be passed when the vendor property shows it is an IBM JDK. Author: Tim Ellison Author: Tim Ellison Closes #6055 from tellison/MaxPermSize and squashes the following commits: 3a0fb66 [Tim Ellison] Convert tabs back to spaces 6ad4266 [Tim Ellison] Remove unnecessary else clauses to reduce nesting. d27174b [Tim Ellison] Merge branch 'master' of https://github.com/apache/spark into MaxPermSize 42a8c3f [Tim Ellison] [MINOR] Avoid passing the PermGenSize option to IBM JVMs. 
--- .../launcher/AbstractCommandBuilder.java | 5 ++++- .../spark/launcher/CommandBuilderUtils.java | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index b8f02b961113d..33fd813f7a86c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -121,7 +121,10 @@ List buildJavaCommand(String extraClassPath) throws IOException { * set it. */ void addPermGenSizeOpt(List cmd) { - // Don't set MaxPermSize for Java 8 and later. + // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. + if (getJavaVendor() == JavaVendor.IBM) { + return; + } String[] version = System.getProperty("java.version").split("\\."); if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { return; diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 261402856ac5e..2665a700fe1f5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -32,6 +32,11 @@ class CommandBuilderUtils { static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; + /** The set of known JVM vendors. */ + static enum JavaVendor { + Oracle, IBM, OpenJDK, Unknown + }; + /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { return s == null || s.isEmpty(); @@ -108,6 +113,21 @@ static boolean isWindows() { return os.startsWith("Windows"); } + /** Returns an enum value indicating whose JVM is being used. */ + static JavaVendor getJavaVendor() { + String vendorString = System.getProperty("java.vendor"); + if (vendorString.contains("Oracle")) { + return JavaVendor.Oracle; + } + if (vendorString.contains("IBM")) { + return JavaVendor.IBM; + } + if (vendorString.contains("OpenJDK")) { + return JavaVendor.OpenJDK; + } + return JavaVendor.Unknown; + } + /** * Updates the user environment, appending the given pathList to the existing value of the given * environment variable (or setting it if it hasn't yet been set). From 3cd9ad2406c59cd0ede6c9c8428a4ce4b805f8fa Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Wed, 13 May 2015 21:01:42 +0100 Subject: [PATCH 013/109] =?UTF-8?q?[MINOR]=20Enhance=20SizeEstimator=20to?= =?UTF-8?q?=20detect=20IBM=20compressed=20refs=20and=20s390=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …arch. - zSeries 64-bit Java reports its architecture as s390x, so enhance the 64-bit check to accommodate that value. - SizeEstimator can detect whether IBM Java is using compressed object pointers using info in the "java.vm.info" property, so will do a better job than failing on the HotSpot MBean and guessing. Author: Tim Ellison Closes #6085 from tellison/SizeEstimator and squashes the following commits: 1b6ff6a [Tim Ellison] Merge branch 'master' of https://github.com/apache/spark into SizeEstimator 0968989 [Tim Ellison] [MINOR] Enhance SizeEstimator to detect IBM compressed refs and s390 arch. 
--- .../scala/org/apache/spark/util/SizeEstimator.scala | 8 +++++++- .../org/apache/spark/util/SizeEstimatorSuite.scala | 12 ++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index d91c3294ddb8b..968a72d5adae9 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -75,7 +75,8 @@ private[spark] object SizeEstimator extends Logging { // Sets object size, pointer size based on architecture and CompressedOops settings // from the JVM. private def initialize() { - is64bit = System.getProperty("os.arch").contains("64") + val arch = System.getProperty("os.arch") + is64bit = arch.contains("64") || arch.contains("s390x") isCompressedOops = getIsCompressedOops objectSize = if (!is64bit) 8 else { @@ -97,6 +98,11 @@ private[spark] object SizeEstimator extends Logging { return System.getProperty("spark.test.useCompressedOops").toBoolean } + // java.vm.info provides compressed ref info for IBM JDKs + if (System.getProperty("java.vendor").contains("IBM")) { + return System.getProperty("java.vm.info").contains("Compressed Ref") + } + try { val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val server = ManagementFactory.getPlatformMBeanServer() diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 133a76f28e000..04f0f3749d6b9 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -45,6 +45,10 @@ class DummyClass6 extends DummyClass5 { val y: Boolean = true } +class DummyClass7 { + val x: DummyClass1 = new DummyClass1 +} + object DummyString { def apply(str: String) : DummyString = new DummyString(str.toArray) } @@ -197,4 +201,12 @@ class SizeEstimatorSuite assertResult(24)(SizeEstimator.estimate(new DummyClass5)) assertResult(32)(SizeEstimator.estimate(new DummyClass6)) } + + test("check 64-bit detection for s390x arch") { + System.setProperty("os.arch", "s390x") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + // Class should be 32 bytes on s390x if recognised as 64 bit platform + assertResult(32)(SizeEstimator.estimate(new DummyClass7)) + } } From 51030b8a9d4f3feb7a5d2249cc867fd6a06f0336 Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Wed, 13 May 2015 21:16:32 +0100 Subject: [PATCH 014/109] [MINOR] [CORE] Accept alternative mesos unsatisfied link error in test. The IBM JVM reports an failed library load with a slightly different error message to Oracle's JVM. Update the test case to allow for either form. Author: Tim Ellison Author: Tim Ellison Closes #6119 from tellison/LibraryLoading and squashes the following commits: 2c5cd4e [Tim Ellison] Reduce assertion to check for the mesos library name f48c194 [Tim Ellison] Split long line b1079d7 [Tim Ellison] [MINOR] [CORE] Accept alternative mesos unsatisfied link error in test. 
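A minimal sketch of the relaxed check: only the library name is asserted on, because the exact wording of the UnsatisfiedLinkError message differs between IBM and Oracle JVMs.

```
object MesosLoadSketch {
  def main(args: Array[String]): Unit = {
    try {
      System.loadLibrary("mesos")
      println("libmesos loaded")
    } catch {
      case e: UnsatisfiedLinkError =>
        // Match on the library name only; the full message is JVM-specific.
        assert(e.getMessage.contains("mesos"))
        println("Mesos native library not available; skipping Mesos-specific checks")
    }
  }
}
```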
--- .../org/apache/spark/SparkContextSchedulerCreationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index bbed8ddc6bafc..9343f4fff89da 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -159,7 +159,7 @@ class SparkContextSchedulerCreationSuite assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => - assert(e.getMessage.contains("no mesos in")) + assert(e.getMessage.contains("mesos")) logWarning("Mesos not available, could not test actual Mesos scheduler creation") case e: Throwable => fail(e) } From 5db18ba6e1bd8c6307c41549176c53590cf344a0 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 13 May 2015 13:21:36 -0700 Subject: [PATCH 015/109] [SPARK-7593] [ML] Python Api for ml.feature.Bucketizer Added `ml.feature.Bucketizer` to PySpark. cc mengxr Author: Burak Yavuz Closes #6124 from brkyvz/ml-bucket and squashes the following commits: 05285be [Burak Yavuz] added sphinx doc 6abb6ed [Burak Yavuz] added support for Bucketizer --- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 15 +++- python/pyspark/ml/feature.py | 77 +++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index b28c88aaaecbc..e52d797293cf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -48,7 +48,7 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer]) * otherwise, values outside the splits specified will be treated as errors. * @group param */ - val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits", + val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits", "Split points for mapping continuous features into buckets. With n+1 splits, there are n " + "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " + "bucket, which also includes y. The splits should be strictly increasing. " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 7ebbf106ee753..5a7ec29aac6cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -219,7 +219,7 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } -/** Specialized version of [[Param[Array[T]]]] for Java. */ +/** Specialized version of [[Param[Array[String]]]] for Java. */ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) extends Param[Array[String]](parent, name, doc, isValid) { @@ -232,6 +232,19 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } +/** Specialized version of [[Param[Array[Double]]]] for Java. 
*/ +class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) + extends Param[Array[Double]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) +} + /** * A param amd its value. */ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f35bc1463d51b..30e1fd4922d0a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -83,6 +83,83 @@ def getThreshold(self): return self.getOrDefault(self.threshold) +@inherit_doc +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.transform(df).collect() + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 + >>> bucketizer.setParams(outputCol="b").transform(df).head().b + 0.0 + """ + + _java_class = "org.apache.spark.ml.feature.Bucketizer" + # a placeholder to make it appear in the generated doc + splits = \ + Param(Params._dummy(), "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + + @keyword_only + def __init__(self, splits=None, inputCol=None, outputCol=None): + """ + __init__(self, splits=None, inputCol=None, outputCol=None) + """ + super(Bucketizer, self).__init__() + #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, + # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) + # except the last bucket, which also includes y. The splits should be strictly increasing. + # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, + # values outside the splits specified will be treated as errors. + self.splits = \ + Param(self, "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, splits=None, inputCol=None, outputCol=None): + """ + setParams(self, splits=None, inputCol=None, outputCol=None) + Sets params for this Bucketizer. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setSplits(self, value): + """ + Sets the value of :py:attr:`splits`. + """ + self.paramMap[self.splits] = value + return self + + def getSplits(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.splits) + + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ From 61e05fc58e1245de871c409b60951745b5db3420 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Wed, 13 May 2015 14:13:19 -0700 Subject: [PATCH 016/109] [SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and predict features have values of 0 or 1 Author: leahmcguire Closes #6073 from leahmcguire/binaryCheckNB and squashes the following commits: b8442c2 [leahmcguire] changed to if else for value checks 911bf83 [leahmcguire] undid reformat 4eedf1e [leahmcguire] moved bernoulli check 9ee9e84 [leahmcguire] fixed style error 3f3b32c [leahmcguire] fixed zero one check so only called in combiner 831fd27 [leahmcguire] got test working f44bb3c [leahmcguire] removed changes from CV branch 67253f0 [leahmcguire] added check to bernoulli to ensure feature values are zero or one f191c71 [leahmcguire] fixed name 58d060b [leahmcguire] changed param name and test according to comments 04f0d3c [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access --- .../mllib/classification/NaiveBayes.scala | 28 ++++++++++++++-- .../classification/NaiveBayesSuite.scala | 33 +++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c9b3ff0172e2e..b381dc2cb0140 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { + val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + labels (brzArgmax (brzPi + brzTheta * brzData) ) case "Bernoulli" => + if (!brzData.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + } labels (brzArgmax (brzPi + - (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") @@ -293,12 +298,29 @@ class NaiveBayes private ( } } + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case SparseVector(size, indices, values) => + values + case DenseVector(values) => + values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) + if (modelType == "Bernoulli") { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea89b17b7c08f..40a79a1f19bd9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("detect non zero or one values in Bernoulli") { + val badTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli") + } + + val okTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)) + ) + + val badPredict = Seq( + Vectors.dense(1.0), + Vectors.dense(2.0), + Vectors.dense(1.0), + Vectors.dense(0.0)) + + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") + intercept[SparkException] { + model.predict(sc.makeRDD(badPredict, 2)).collect() + } + } + test("model save/load: 2.0 to 2.0") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString From df2fb1305aba6781017b0973b0965b664f835e31 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 13 May 2015 15:13:09 -0700 Subject: [PATCH 017/109] [SPARK-7382] [MLLIB] Feature Parity in PySpark for ml.classification The missing pieces in ml.classification for Python! 
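For reference, a rough Scala sketch of the JVM-side estimators the new Python wrappers delegate to over Py4J (class names come from the _java_class strings in the diff; the constructors and setters shown are assumptions drawn from the Python params, not a complete or exact API).

```
import org.apache.spark.ml.classification.{DecisionTreeClassifier, GBTClassifier, RandomForestClassifier}

object MlClassifiersSketch {
  def main(args: Array[String]): Unit = {
    // The Python classes below construct and configure these estimators on the JVM side.
    val dt  = new DecisionTreeClassifier().setMaxDepth(5).setImpurity("gini")
    val rf  = new RandomForestClassifier().setNumTrees(20).setFeatureSubsetStrategy("auto")
    val gbt = new GBTClassifier().setMaxIter(20).setLossType("logistic")
    // Each is then fit on a DataFrame with a features vector column and a label column,
    // e.g. val model = rf.fit(trainingDF); model.transform(testDF)
    println(Seq(dt, rf, gbt).map(_.getClass.getSimpleName).mkString(", "))
  }
}
```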
cc mengxr Author: Burak Yavuz Closes #6106 from brkyvz/ml-class and squashes the following commits: dd78237 [Burak Yavuz] fix style 1048e29 [Burak Yavuz] ready for PR --- python/pyspark/ml/classification.py | 478 +++++++++++++++++- .../ml/param/_shared_params_code_gen.py | 4 + python/pyspark/ml/param/shared.py | 29 ++ 3 files changed, 501 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8a009c4ac721f..96d29058a3781 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -17,17 +17,19 @@ from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ - HasRegParam +from pyspark.ml.param.shared import * +from pyspark.ml.regression import RandomForestParams from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', + 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam): + HasRegParam, HasTol, HasProbabilityCol): """ Logistic regression. @@ -50,25 +52,49 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.classification.LogisticRegression" + # a placeholder to make it appear in the generated doc + elasticNetParam = \ + Param(Params._dummy(), "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + threshold = Param(Params._dummy(), "threshold", + "threshold in binary classification prediction, in range [0, 1].") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1): + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1) + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability") """ super(LogisticRegression, self).__init__() - self._setDefault(maxIter=100, regParam=0.1) + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty + # is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = \ + Param(self, "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + + "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + #: param for threshold in binary classification prediction, in range [0, 1]. 
+ self.threshold = Param(self, "threshold", + "threshold in binary classification prediction, in range [0, 1].") + self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1): + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1) + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability") Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs @@ -77,6 +103,45 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self.paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self.paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + self.paramMap[self.threshold] = value + return self + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + class LogisticRegressionModel(JavaModel): """ @@ -84,6 +149,399 @@ class LogisticRegressionModel(JavaModel): """ +class TreeClassifierParams(object): + """ + Private class to track supported impurity measures. + """ + supportedImpurities = ["entropy", "gini"] + + +class GBTParams(object): + """ + Private class to track supported GBT params. + """ + supportedLossTypes = ["logistic"] + + +@inherit_doc +class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") + >>> model = dt.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier" + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + """ + super(DecisionTreeClassifier, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini") + Sets params for the DecisionTreeClassifier. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return DecisionTreeClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self.paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +class DecisionTreeClassificationModel(JavaModel): + """ + Model fitted by DecisionTreeClassifier. + """ + + +@inherit_doc +class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Random_forest Random Forest` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. 
+ + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed") + >>> model = rf.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + _java_class = "org.apache.spark.ml.classification.RandomForestClassifier" + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + numTrees=20, featureSubsetStrategy="auto", seed=42): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + numTrees=20, featureSubsetStrategy="auto", seed=42) + """ + super(RandomForestClassifier, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + #: param for Fraction of the training data used for learning each decision tree, + # in range (0, 1] + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: param for Number of trees to train (>= 1) + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") + #: param for The number of features to consider for splits at each tree node + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. 
Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + impurity="gini", numTrees=20, featureSubsetStrategy="auto"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + Sets params for linear classification. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return RandomForestClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self.paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self.paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self.paramMap[self.numTrees] = value + return self + + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + self.paramMap[self.featureSubsetStrategy] = value + return self + + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class RandomForestClassificationModel(JavaModel): + """ + Model fitted by RandomForestClassifier. + """ + + +@inherit_doc +class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + learning algorithm for classification. + It supports binary labels, as well as both continuous and categorical features. + Note: Multiclass labels are not currently supported. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") + >>> model = gbt.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + _java_class = "org.apache.spark.ml.classification.GBTClassifier" + # a placeholder to make it appear in the generated doc + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + + "contribution of each estimator") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", + maxIter=20, stepSize=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", + maxIter=20, stepSize=0.1) + """ + super(GBTClassifier, self).__init__() + #: param for Loss function which GBT tries to minimize (case-insensitive). + self.lossType = Param(self, "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + #: Fraction of the training data used for learning each decision tree, in range (0, 1]. + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of + # each estimator + self.stepSize = Param(self, "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator") + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1) + Sets params for Gradient Boosted Tree Classification. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GBTClassificationModel(java_model) + + def setLossType(self, value): + """ + Sets the value of :py:attr:`lossType`. + """ + self.paramMap[self.lossType] = value + return self + + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self.paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + self.paramMap[self.stepSize] = value + return self + + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + +class GBTClassificationModel(JavaModel): + """ + Model fitted by GBTClassifier. + """ + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 4a5cc6e64f023..6fa9b8c2cf367 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -109,6 +109,9 @@ def get$Name(self): ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), + ("probabilityCol", "Column name for predicted class conditional probabilities. " + + "Note: Not all models output well-calibrated probability estimates! These probabilities " + + "should be treated as confidences, not precise probabilities.", "'probability'"), ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"), ("inputCol", "input column name", None), ("inputCols", "input column names", None), @@ -156,6 +159,7 @@ def __init__(self): for name, doc in decisionTreeParams: variable = paramTemplate.replace("$name", name).replace("$doc", doc) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " + realParams += "#: param for " + doc + "\n " realParams += "self." + variable.replace("$owner", "self") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 779cabe853f8e..b116f05a068d3 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -165,6 +165,35 @@ def getPredictionCol(self): return self.getOrDefault(self.predictionCol) +class HasProbabilityCol(Params): + """ + Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + """ + + # a placeholder to make it appear in the generated doc + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! 
These probabilities should be treated as confidences, not precise probabilities.") + + def __init__(self): + super(HasProbabilityCol, self).__init__() + #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. + self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + if 'probability' is not None: + self._setDefault(probabilityCol='probability') + + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + self.paramMap[self.probabilityCol] = value + return self + + def getProbabilityCol(self): + """ + Gets the value of probabilityCol or its default value. + """ + return self.getOrDefault(self.probabilityCol) + + class HasRawPredictionCol(Params): """ Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. From 59250fe51486908f9e3f3d9ef10aadbcb9b4d62d Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 13 May 2015 16:13:48 -0700 Subject: [PATCH 018/109] [SPARK-7303] [SQL] push down project if possible when the child is sort Optimize the case of `project(_, sort)` , a example is: `select key from (select * from testData order by key) t` before this PR: ``` == Parsed Logical Plan == 'Project ['key] 'Subquery t 'Sort ['key ASC], true 'Project [*] 'UnresolvedRelation [testData], None == Analyzed Logical Plan == Project [key#0] Subquery t Sort [key#0 ASC], true Project [key#0,value#1] Subquery testData LogicalRDD [key#0,value#1], MapPartitionsRDD[1] == Optimized Logical Plan == Project [key#0] Sort [key#0 ASC], true LogicalRDD [key#0,value#1], MapPartitionsRDD[1] == Physical Plan == Project [key#0] Sort [key#0 ASC], true Exchange (RangePartitioning [key#0 ASC], 5), [] PhysicalRDD [key#0,value#1], MapPartitionsRDD[1] ``` after this PR ``` == Parsed Logical Plan == 'Project ['key] 'Subquery t 'Sort ['key ASC], true 'Project [*] 'UnresolvedRelation [testData], None == Analyzed Logical Plan == Project [key#0] Subquery t Sort [key#0 ASC], true Project [key#0,value#1] Subquery testData LogicalRDD [key#0,value#1], MapPartitionsRDD[1] == Optimized Logical Plan == Sort [key#0 ASC], true Project [key#0] LogicalRDD [key#0,value#1], MapPartitionsRDD[1] == Physical Plan == Sort [key#0 ASC], true Exchange (RangePartitioning [key#0 ASC], 5), [] Project [key#0] PhysicalRDD [key#0,value#1], MapPartitionsRDD[1] ``` with this rule we will first do column pruning on the table and then do sorting. 
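The same query shape expressed through the DataFrame API, as a rough illustration of when the rewrite applies (assuming an existing SparkContext `sc`):

```
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val testData = sc.parallelize(Seq((1, "a"), (2, "b"))).toDF("key", "value")

// Only `key` survives the final projection and the sort itself only needs `key`,
// so the projection can now be pushed below the sort:
testData.orderBy("key").select("key").explain(true)

// The rewrite does not apply when the sort needs a column the projection drops:
testData.orderBy("key").select("value").explain(true)
```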
Author: scwf This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #5838 from scwf/pruning and squashes the following commits: b00d833 [scwf] address michael's comment e230155 [scwf] fix tests failure b09b895 [scwf] improve column pruning --- .../sql/catalyst/optimizer/Optimizer.scala | 5 +++ .../optimizer/FilterPushdownSuite.scala | 36 ++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b163707cc9925..c2818d957cc79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -156,6 +156,11 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) + // push down project if possible when the child is sort + case p @ Project(projectList, s @ Sort(_, _, grandChild)) + if s.references.subsetOf(p.outputSet) => + s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0c428f7231b8e..be33cb9bb8eaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.expressions.{Count, Explode} +import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -542,4 +542,38 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + + test("push down project past sort") { + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) + + } } From e683182c3e6347afdac0e5658487f80e5e054ef4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 13 May 2015 16:15:31 -0700 Subject: [PATCH 019/109] [SQL] Move some classes into packages that are more appropriate. 
JavaTypeInference into catalyst types.DateUtils into catalyst CacheManager into execution DefaultParserDialect into catalyst Author: Reynold Xin Closes #6108 from rxin/sql-rename and squashes the following commits: 3fc9613 [Reynold Xin] Fixed import ordering. 83d9ff4 [Reynold Xin] Fixed codegen tests. e271e86 [Reynold Xin] mima f4e24a6 [Reynold Xin] [SQL] Move some classes into packages that are more appropriate. --- project/MimaExcludes.scala | 5 ++- .../sql/catalyst/CatalystTypeConverters.scala | 1 + .../sql/catalyst}/JavaTypeInference.scala | 4 +- .../spark/sql/catalyst/ParserDialect.scala | 36 +++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 1 + .../{types => catalyst/util}/DateUtils.scala | 2 +- .../apache/spark/sql/types/UTF8String.scala | 17 +++++--- .../ExpressionEvaluationSuite.scala | 1 + .../scala/org/apache/spark/sql/Column.scala | 2 + .../org/apache/spark/sql/SQLContext.scala | 40 +------------------ .../sql/{ => execution}/CacheManager.scala | 5 ++- .../spark/sql/execution/pythonUdfs.scala | 5 ++- .../org/apache/spark/sql/functions.scala | 1 + .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- .../apache/spark/sql/json/JacksonParser.scala | 1 + .../org/apache/spark/sql/json/JsonRDD.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 1 + .../spark/sql/parquet/ParquetIOSuite.scala | 1 + .../spark/sql/hive/HiveInspectors.scala | 1 + .../apache/spark/sql/hive/TableReader.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- 24 files changed, 80 insertions(+), 57 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/JavaTypeInference.scala (99%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/DateUtils.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/CacheManager.scala (97%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f31f0e554eee9..fba7290dcb0b5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -123,7 +123,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetTestData$"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport") + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager") ) ++ Seq( // SPARK-7530 Added StreamingContext.getState() ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index a13e2f36a1a1f..75a493b248f6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -23,6 +23,7 @@ import java.util.{Map => JavaMap} import scala.collection.mutable.HashMap import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1ec874f79617c..625c8d3a62125 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.catalyst import java.beans.Introspector import java.lang.{Iterable => JIterable} @@ -24,10 +24,8 @@ import java.util.{Iterator => JIterator, Map => JMap} import scala.language.existentials import com.google.common.reflect.TypeToken - import org.apache.spark.sql.types._ - /** * Type-inference utilities for POJOs and Java collections. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 05a92b06f9fd9..554fb4eb25eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -31,3 +31,39 @@ abstract class ParserDialect { // this is the main function that will be implemented by sql parser. def parse(sqlText: String): LogicalPlan } + +/** + * Currently we support the default dialect named "sql", associated with the class + * [[DefaultParserDialect]] + * + * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: + * {{{ + *-- switch to "hiveql" dialect + * spark-sql>SET spark.sql.dialect=hiveql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- switch to "sql" dialect + * spark-sql>SET spark.sql.dialect=sql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- register the new SQL dialect + * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- register the non-exist SQL dialect + * spark-sql> SET spark.sql.dialect=NotExistedClass; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- Exception will be thrown and switch to dialect + *-- "sql" (for SQLContext) or + *-- "hiveql" (for HiveContext) + * }}} + */ +private[spark] class DefaultParserDialect extends ParserDialect { + @transient + protected val sqlParser = new SqlParser + + override def parse(sqlText: String): LogicalPlan = { + sqlParser.parse(sqlText) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index adf941ab2a45f..d8cf2b2e32435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d17af0e7ff87e..ecb4c4b68f904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -250,7 +250,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case Cast(child @ DateType(), StringType) => child.castOrNull(c => q"""org.apache.spark.sql.types.UTF8String( - org.apache.spark.sql.types.DateUtils.toString($c))""", + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) case Cast(child @ NumericType(), IntegerType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 18cba4cc46707..5f8c7354aede1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ object Literal { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index d36a49159b87f..3f92be4a55d7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import java.sql.Date import java.text.SimpleDateFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index fc02ba6c9c43e..bc9c37bf2d5d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -19,15 +19,18 @@ package org.apache.spark.sql.types import java.util.Arrays +import org.apache.spark.annotation.DeveloperApi + /** - * A UTF-8 String, as internal representation of StringType in SparkSQL + * :: DeveloperApi :: + * A UTF-8 String, as internal representation of StringType in SparkSQL * - * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, - * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. * - * Note: This is not designed for general use cases, should not be used outside SQL. + * Note: This is not designed for general use cases, should not be used outside SQL. 
*/ - +@DeveloperApi final class UTF8String extends Ordered[UTF8String] with Serializable { private[this] var bytes: Array[Byte] = _ @@ -180,6 +183,10 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 04fd261d16aa3..5c4a1527c27c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 42f5bcda49cfb..8bf1320ccb71d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -346,6 +346,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * }}} * * @group expr_ops + * @since 1.4.0 */ def when(condition: Column, value: Any):Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => @@ -374,6 +375,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * }}} * * @group expr_ops + * @since 1.4.0 */ def otherwise(value: Any):Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a148c7cd2d3b..521f3dc821795 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -33,6 +33,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -40,7 +41,6 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.ParserDialect -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ @@ -50,42 +50,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.{Partition, SparkContext} -/** - * Currently we support the default dialect named "sql", associated with the class - * 
[[DefaultParserDialect]] - * - * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: - * {{{ - *-- switch to "hiveql" dialect - * spark-sql>SET spark.sql.dialect=hiveql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- switch to "sql" dialect - * spark-sql>SET spark.sql.dialect=sql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- register the new SQL dialect - * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- register the non-exist SQL dialect - * spark-sql> SET spark.sql.dialect=NotExistedClass; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- Exception will be thrown and switch to dialect - *-- "sql" (for SQLContext) or - *-- "hiveql" (for HiveContext) - * }}} - */ -private[spark] class DefaultParserDialect extends ParserDialect { - @transient - protected val sqlParser = new catalyst.SqlParser - - override def parse(sqlText: String): LogicalPlan = { - sqlParser.parse(sqlText) - } -} - /** * The entry point for working with structured data (rows and columns) in Spark. Allows the * creation of [[DataFrame]] objects as well as the execution of SQL queries. @@ -1276,7 +1240,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And) + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 18584c2dcf797..5fcc48a67948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -15,18 +15,19 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ -private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) +private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) /** * Provides support in a SQLContext for caching query results and automatically using these cached diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 3dbc3837950e0..65dd7ba020fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,20 +19,21 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} -import org.apache.spark.rdd.RDD - import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ import org.apache.spark.{Accumulator, Logging => SparkLogging} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 099e1d8f03272..4404ad8ad63a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -438,6 +438,7 @@ object functions { * }}} * * @group normal_funcs + * @since 1.4.0 */ def when(condition: Column, value: Any): Column = { CaseWhen(Seq(condition.expr, lit(value).expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index a03ade3881f59..40483d3ec7701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -25,9 +25,9 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ -import org.apache.spark.util.Utils private[sql] object JDBCRDD extends Logging { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index a8e69ae61174f..81611513582a8 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -26,6 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f62973d5fcfab..4c32710a17bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ import org.apache.spark.Logging diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ec0e76cde6f7c..8cdbe076cbd85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 263fafba930ce..b06e3385980f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7c371dbc7d3c9..008443df216aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -35,6 +35,7 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 74ae984f34866..7c7666f6e4b7c 100644 
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b69312f0f8717..0b6f7a334a715 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -35,7 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.DateUtils +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.util.Utils /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1d6393a3fec85..eaa9d6aad1f31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -26,7 +28,6 @@ import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} import org.apache.spark.sql.parquet.FSBasedParquetRelation import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DefaultParserDialect, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) From f6e18388d993d99f768c6d547327e0720ec64224 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 16:27:48 -0700 Subject: [PATCH 020/109] [SPARK-7608] Clean up old state in RDDOperationGraphListener This is necessary for streaming and long-running Spark applications. 
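A minimal, standalone sketch of the bounded-retention pattern this patch applies (not Spark's actual RDDOperationGraphListener): record insertion order in a buffer and, once the retained limit is reached, drop roughly the oldest 10% of entries from both the buffer and the lookup map. The class name BoundedJobState and the retainedJobs parameter here are illustrative only; the trimming arithmetic mirrors the diff below.

import scala.collection.mutable

class BoundedJobState(retainedJobs: Int) {
  // Insertion order of job IDs, used to decide which entries are "old".
  private val jobIds = new mutable.ArrayBuffer[Int]
  // Per-job metadata that must be released along with the job ID.
  private val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]]

  def onJobStart(jobId: Int, stageIds: Seq[Int]): Unit = {
    jobIds += jobId
    jobIdToStageIds(jobId) = stageIds.sorted
    // Remove state for old jobs, oldest first, once the limit is hit.
    if (jobIds.size >= retainedJobs) {
      val toRemove = math.max(retainedJobs / 10, 1)
      jobIds.take(toRemove).foreach { id => jobIdToStageIds.remove(id) }
      jobIds.trimStart(toRemove)
    }
  }

  def retainedJobIds: Seq[Int] = jobIds.toSeq
}

Without this kind of trimming, a long-running or streaming application would accumulate one entry per job forever, which is the leak the commit fixes.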
zsxwing tdas Author: Andrew Or Closes #6125 from andrewor14/viz-listener-leak and squashes the following commits: 8660949 [Andrew Or] Fix thing + add tests 33c0843 [Andrew Or] Clean up old job state --- .../ui/scope/RDDOperationGraphListener.scala | 30 +++++-- .../RDDOperationGraphListenerSuite.scala | 87 +++++++++++++++++++ 2 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala index 2884a49f31122..f0f7007d77a14 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala @@ -27,11 +27,16 @@ import org.apache.spark.ui.SparkUI * A SparkListener that constructs a DAG of RDD operations. */ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListener { - private val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]] - private val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph] - private val stageIds = new mutable.ArrayBuffer[Int] + private[ui] val jobIdToStageIds = new mutable.HashMap[Int, Seq[Int]] + private[ui] val stageIdToGraph = new mutable.HashMap[Int, RDDOperationGraph] + + // Keep track of the order in which these are inserted so we can remove old ones + private[ui] val jobIds = new mutable.ArrayBuffer[Int] + private[ui] val stageIds = new mutable.ArrayBuffer[Int] // How many jobs or stages to retain graph metadata for + private val retainedJobs = + conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) private val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) @@ -50,15 +55,22 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen /** On job start, construct a RDDOperationGraph for each stage in the job for display later. 
*/ override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobId = jobStart.jobId - val stageInfos = jobStart.stageInfos + jobIds += jobId + jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted - stageInfos.foreach { stageInfo => - stageIds += stageInfo.stageId - stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo) + // Remove state for old jobs + if (jobIds.size >= retainedJobs) { + val toRemove = math.max(retainedJobs / 10, 1) + jobIds.take(toRemove).foreach { id => jobIdToStageIds.remove(id) } + jobIds.trimStart(toRemove) } - jobIdToStageIds(jobId) = stageInfos.map(_.stageId).sorted + } - // Remove graph metadata for old stages + /** Remove graph metadata for old stages */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { + val stageInfo = stageSubmitted.stageInfo + stageIds += stageInfo.stageId + stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo) if (stageIds.size >= retainedStages) { val toRemove = math.max(retainedStages / 10, 1) stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) } diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala new file mode 100644 index 0000000000000..619b38ac02676 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.scope + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo} + +class RDDOperationGraphListenerSuite extends FunSuite { + private var jobIdCounter = 0 + private var stageIdCounter = 0 + + /** Run a job with the specified number of stages. 
*/ + private def runOneJob(numStages: Int, listener: RDDOperationGraphListener): Unit = { + assert(numStages > 0, "I will not run a job with 0 stages for you.") + val stageInfos = (0 until numStages).map { _ => + val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d") + listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo)) + stageIdCounter += 1 + stageInfo + } + listener.onJobStart(new SparkListenerJobStart(jobIdCounter, 0, stageInfos)) + jobIdCounter += 1 + } + + test("listener cleans up metadata") { + + val conf = new SparkConf() + .set("spark.ui.retainedStages", "10") + .set("spark.ui.retainedJobs", "10") + + val listener = new RDDOperationGraphListener(conf) + assert(listener.jobIdToStageIds.isEmpty) + assert(listener.stageIdToGraph.isEmpty) + assert(listener.jobIds.isEmpty) + assert(listener.stageIds.isEmpty) + + // Run a few jobs, but not enough for clean up yet + runOneJob(1, listener) + runOneJob(2, listener) + runOneJob(3, listener) + assert(listener.jobIdToStageIds.size === 3) + assert(listener.stageIdToGraph.size === 6) + assert(listener.jobIds.size === 3) + assert(listener.stageIds.size === 6) + + // Run a few more, but this time the stages should be cleaned up, but not the jobs + runOneJob(5, listener) + runOneJob(100, listener) + assert(listener.jobIdToStageIds.size === 5) + assert(listener.stageIdToGraph.size === 9) + assert(listener.jobIds.size === 5) + assert(listener.stageIds.size === 9) + + // Run a few more, but this time both jobs and stages should be cleaned up + (1 to 100).foreach { _ => + runOneJob(1, listener) + } + assert(listener.jobIdToStageIds.size === 9) + assert(listener.stageIdToGraph.size === 9) + assert(listener.jobIds.size === 9) + assert(listener.stageIds.size === 9) + + // Ensure we clean up old jobs and stages, not arbitrary ones + assert(!listener.jobIdToStageIds.contains(0)) + assert(!listener.stageIdToGraph.contains(0)) + assert(!listener.stageIds.contains(0)) + assert(!listener.jobIds.contains(0)) + } + +} From f88ac701552a1a854247509db49d78f13515eae4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 16:28:37 -0700 Subject: [PATCH 021/109] [SPARK-7399] Spark compilation error for scala 2.11 Subsequent fix following #5966. I tried this out locally. Author: Andrew Or Closes #6129 from andrewor14/211-compilation and squashes the following commits: 713868f [Andrew Or] Fix compilation issue for scala 2.11 --- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../apache/spark/rdd/RDDOperationScope.scala | 4 ++-- .../spark/rdd/RDDOperationScopeSuite.scala | 20 ++++++++++--------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 02a94baf372d9..f7fa37e4cdcdc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1524,7 +1524,7 @@ abstract class RDD[T: ClassTag]( * doCheckpoint() is called recursively on the parent RDDs. 
*/ private[spark] def doCheckpoint(): Unit = { - RDDOperationScope.withScope(sc, "checkpoint", false, true) { + RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) { if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 93ec606f2de7d..2725826f421f4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -96,7 +96,7 @@ private[spark] object RDDOperationScope { sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName - withScope[T](sc, callerMethodName, allowNesting)(body) + withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body) } /** @@ -116,7 +116,7 @@ private[spark] object RDDOperationScope { sc: SparkContext, name: String, allowNesting: Boolean, - ignoreParent: Boolean = false)(body: => T): T = { + ignoreParent: Boolean)(body: => T): T = { // Save the old scope to restore it later val scopeKey = SparkContext.RDD_SCOPE_KEY val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala index d75ecbf1f0b4d..db465a6a9eb55 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDOperationScopeSuite.scala @@ -61,11 +61,11 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { var rdd1: MyCoolRDD = null var rdd2: MyCoolRDD = null var rdd3: MyCoolRDD = null - RDDOperationScope.withScope(sc, "scope1", allowNesting = false) { + RDDOperationScope.withScope(sc, "scope1", allowNesting = false, ignoreParent = false) { rdd1 = new MyCoolRDD(sc) - RDDOperationScope.withScope(sc, "scope2", allowNesting = false) { + RDDOperationScope.withScope(sc, "scope2", allowNesting = false, ignoreParent = false) { rdd2 = new MyCoolRDD(sc) - RDDOperationScope.withScope(sc, "scope3", allowNesting = false) { + RDDOperationScope.withScope(sc, "scope3", allowNesting = false, ignoreParent = false) { rdd3 = new MyCoolRDD(sc) } } @@ -84,11 +84,13 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { var rdd1: MyCoolRDD = null var rdd2: MyCoolRDD = null var rdd3: MyCoolRDD = null - RDDOperationScope.withScope(sc, "scope1", allowNesting = true) { // allow nesting here + // allow nesting here + RDDOperationScope.withScope(sc, "scope1", allowNesting = true, ignoreParent = false) { rdd1 = new MyCoolRDD(sc) - RDDOperationScope.withScope(sc, "scope2", allowNesting = false) { // stop nesting here + // stop nesting here + RDDOperationScope.withScope(sc, "scope2", allowNesting = false, ignoreParent = false) { rdd2 = new MyCoolRDD(sc) - RDDOperationScope.withScope(sc, "scope3", allowNesting = false) { + RDDOperationScope.withScope(sc, "scope3", allowNesting = false, ignoreParent = false) { rdd3 = new MyCoolRDD(sc) } } @@ -107,11 +109,11 @@ class RDDOperationScopeSuite extends FunSuite with BeforeAndAfter { var rdd1: MyCoolRDD = null var rdd2: MyCoolRDD = null var rdd3: MyCoolRDD = null - RDDOperationScope.withScope(sc, "scope1", allowNesting = true) { + RDDOperationScope.withScope(sc, "scope1", allowNesting = true, ignoreParent = false) { rdd1 = new MyCoolRDD(sc) - 
RDDOperationScope.withScope(sc, "scope2", allowNesting = true) { + RDDOperationScope.withScope(sc, "scope2", allowNesting = true, ignoreParent = false) { rdd2 = new MyCoolRDD(sc) - RDDOperationScope.withScope(sc, "scope3", allowNesting = true) { + RDDOperationScope.withScope(sc, "scope3", allowNesting = true, ignoreParent = false) { rdd3 = new MyCoolRDD(sc) } } From 44403414d3e754f7b991c0bbeb4868edb4135aa2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 16:29:10 -0700 Subject: [PATCH 022/109] [SPARK-7464] DAG visualization: highlight the same RDDs on hover This is pretty useful for MLlib. Author: Andrew Or Closes #6100 from andrewor14/dag-viz-hover and squashes the following commits: fefe2af [Andrew Or] Link tooltips for nodes that belong to the same RDD 90c6a7e [Andrew Or] Assign classes to clusters and nodes, not IDs --- .../apache/spark/ui/static/dagre-d3.min.js | 2 +- .../apache/spark/ui/static/spark-dag-viz.css | 4 +- .../apache/spark/ui/static/spark-dag-viz.js | 47 ++++++++++++++----- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index acf2d93b718b2..c55f752620dfd 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -20,7 +20,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return 
util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("id",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return 
selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var 
_=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ +module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var 
clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("class",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var 
transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var 
_=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function 
buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var 
node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return 
_.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var 
blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var 
w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var 
_=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return 
fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. 
"+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var 
keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return 
_.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var 
json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function 
cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 18c72694f3e2d..1846acb742b98 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -57,7 +57,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.job g.cluster[id*="stage"] rect { +#dag-viz-graph svg.job g.cluster[class*="stage"] rect { fill: #FFFFFF; stroke: #FF99AC; stroke-width: 1px; @@ -79,7 +79,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.stage g.cluster[id*="stage"] rect { +#dag-viz-graph svg.stage g.cluster[class*="stage"] rect { fill: #FFFFFF; stroke: #FFA6B6; stroke-width: 1px; diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index f7d0d3c61457c..e2ec00b9c3c0d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -108,7 +108,7 @@ function toggleDagViz(forJob) { * Output DOM 
hierarchy: * div#dag-viz-graph > * svg > - * g#cluster_stage_[stageId] + * g.cluster_stage_[stageId] * * Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz. * Any changes in the input format here must be reflected there. @@ -137,7 +137,7 @@ function renderDagViz(forJob) { // Find cached RDDs and mark them as such metadataContainer().selectAll(".cached-rdd").each(function(v) { var nodeId = VizConstants.nodePrefix + d3.select(this).text(); - svg.selectAll("#" + nodeId).classed("cached", true); + svg.selectAll("g." + nodeId).classed("cached", true); }); resizeSvg(svg); @@ -192,14 +192,10 @@ function renderDagVizForJob(svgContainer) { if (i > 0) { var existingStages = svgContainer .selectAll("g.cluster") - .filter("[id*=\"" + VizConstants.stageClusterPrefix + "\"]"); + .filter("[class*=\"" + VizConstants.stageClusterPrefix + "\"]"); if (!existingStages.empty()) { var lastStage = d3.select(existingStages[0].pop()); - var lastStageId = lastStage.attr("id"); - var lastStageWidth = toFloat(svgContainer - .select("#" + lastStageId) - .select("rect") - .attr("width")); + var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); var lastStagePosition = getAbsolutePosition(lastStage); var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep; container.attr("transform", "translate(" + offset + ", 0)"); @@ -372,14 +368,14 @@ function getAbsolutePosition(d3selection) { function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) { var fromNodeId = VizConstants.nodePrefix + fromRDDId; var toNodeId = VizConstants.nodePrefix + toRDDId; - var fromPos = getAbsolutePosition(svgContainer.select("#" + fromNodeId)); - var toPos = getAbsolutePosition(svgContainer.select("#" + toNodeId)); + var fromPos = getAbsolutePosition(svgContainer.select("g." + fromNodeId)); + var toPos = getAbsolutePosition(svgContainer.select("g." + toNodeId)); // On the job page, RDDs are rendered as dots (circles). When rendering the path, // we need to account for the radii of these circles. Otherwise the arrow heads // will bleed into the circle itself. var delta = toFloat(svgContainer - .select("g.node#" + toNodeId) + .select("g.node." + toNodeId) .select("circle") .attr("r")); if (fromPos.x < toPos.x) { @@ -431,10 +427,35 @@ function addTooltipsForRDDs(svgContainer) { node.select("circle") .attr("data-toggle", "tooltip") .attr("data-placement", "bottom") - .attr("title", tooltipText) + .attr("title", tooltipText); } + // Link tooltips for all nodes that belong to the same RDD + node.on("mouseenter", function() { triggerTooltipForRDD(node, true); }); + node.on("mouseleave", function() { triggerTooltipForRDD(node, false); }); }); - $("[data-toggle=tooltip]").tooltip({container: "body"}); + + $("[data-toggle=tooltip]") + .filter("g.node circle") + .tooltip({ container: "body", trigger: "manual" }); +} + +/* + * (Job page only) Helper function to show or hide tooltips for all nodes + * in the graph that refer to the same RDD the specified node represents. + */ +function triggerTooltipForRDD(d3node, show) { + var classes = d3node.node().classList; + for (var i = 0; i < classes.length; i++) { + var clazz = classes[i]; + var isRDDClass = clazz.indexOf(VizConstants.nodePrefix) == 0; + if (isRDDClass) { + graphContainer().selectAll("g." + clazz).each(function() { + var circle = d3.select(this).select("circle").node(); + var showOrHide = show ? "show" : "hide"; + $(circle).tooltip(showOrHide); + }); + } + } } /* Helper function to convert attributes to numeric values. 
*/ From aa1837875a3febad2f22b91a294f91749852b42f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 16:29:52 -0700 Subject: [PATCH 023/109] [SPARK-7502] DAG visualization: gracefully handle removed stages Old stages are removed without much feedback to the user. This happens very often in streaming. See screenshots below for more detail. zsxwing **Before** ------------------------- **After** Author: Andrew Or Closes #6132 from andrewor14/dag-viz-remove-gracefully and squashes the following commits: 43175cd [Andrew Or] Handle removed jobs and stages gracefully --- .../apache/spark/ui/static/spark-dag-viz.css | 4 ++++ .../apache/spark/ui/static/spark-dag-viz.js | 18 +++++++++++++----- .../ui/scope/RDDOperationGraphListener.scala | 11 ++++++++--- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 1846acb742b98..eedefb44b96fc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -44,6 +44,10 @@ stroke-width: 1px; } +#dag-viz-graph div#empty-dag-viz-message { + margin: 15px; +} + /* Job page specific styles */ #dag-viz-graph svg.job marker#marker-arrow path { diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index e2ec00b9c3c0d..8138eb0d4f390 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -86,7 +86,7 @@ function toggleDagViz(forJob) { $(arrowSelector).toggleClass('arrow-open'); var shouldShow = $(arrowSelector).hasClass("arrow-open"); if (shouldShow) { - var shouldRender = graphContainer().select("svg").empty(); + var shouldRender = graphContainer().select("*").empty(); if (shouldRender) { renderDagViz(forJob); } @@ -117,10 +117,18 @@ function renderDagViz(forJob) { // If there is not a dot file to render, fail fast and report error var jobOrStage = forJob ? "job" : "stage"; - if (metadataContainer().empty()) { - graphContainer() - .append("div") - .text("No visualization information available for this " + jobOrStage); + if (metadataContainer().empty() || + metadataContainer().selectAll("div").empty()) { + var message = + "No visualization information available for this " + jobOrStage + "!
" + + "If this is an old " + jobOrStage + ", its visualization metadata may have been " + + "cleaned up over time.
You may consider increasing the value of "; + if (forJob) { + message += "spark.ui.retainedJobs and spark.ui.retainedStages."; + } else { + message += "spark.ui.retainedStages"; + } + graphContainer().append("div").attr("id", "empty-dag-viz-message").html(message); return; } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala index f0f7007d77a14..3b77a1e12cc45 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala @@ -42,9 +42,14 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen /** Return the graph metadata for the given stage, or None if no such information exists. */ def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { - jobIdToStageIds.get(jobId) - .map { sids => sids.flatMap { sid => stageIdToGraph.get(sid) } } - .getOrElse { Seq.empty } + val stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty } + val graphs = stageIds.flatMap { sid => stageIdToGraph.get(sid) } + // If the metadata for some stages have been removed, do not bother rendering this job + if (stageIds.size != graphs.size) { + Seq.empty + } else { + graphs + } } /** Return the graph metadata for the given stage, or None if no such information exists. */ From bb6dec3b160b54488892a509965fee70a530deff Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 16:31:24 -0700 Subject: [PATCH 024/109] [STREAMING] [MINOR] Keep streaming.UIUtils private zsxwing Author: Andrew Or Closes #6134 from andrewor14/private-streaming-uiutils and squashes the following commits: 225df94 [Andrew Or] Privatize class --- .../src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index c206f973b2c66..f153ee105a18e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.ui import java.util.concurrent.TimeUnit -object UIUtils { +private[streaming] object UIUtils { /** * Return the short string for a `TimeUnit`. From 61d1e87c0d3d12dac0b724d1b84436f748227e99 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 13 May 2015 16:43:30 -0700 Subject: [PATCH 025/109] [SPARK-7356] [STREAMING] Fix flakey tests in FlumePollingStreamSuite using SparkSink's batch CountDownLatch. This is meant to make the FlumePollingStreamSuite deterministic. Now we basically count the number of batches that have been completed - and then verify the results rather than sleeping for random periods of time. Author: Hari Shreedharan Closes #5918 from harishreedharan/flume-test-fix and squashes the following commits: 93f24f3 [Hari Shreedharan] Add an eventually block to ensure that all received data is processed. Refactor the dstream creation and remove redundant code. 1108804 [Hari Shreedharan] [SPARK-7356][STREAMING] Fix flakey tests in FlumePollingStreamSuite using SparkSink's batch CountDownLatch. 
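To illustrate the approach described in this commit message, here is a minimal, self-contained Scala sketch of the pattern the patch adopts: producers count a shared CountDownLatch down once per delivered batch, the test blocks on the latch with a bounded timeout, and the final assertions run inside ScalaTest's eventually block rather than after a fixed sleep. The LatchBasedVerificationSketch object, its producer thread, and the collected buffer are hypothetical stand-ins for the suite's SparkSink/DStream plumbing, not the actual test code.

import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.concurrent.Eventually._

object LatchBasedVerificationSketch {
  def main(args: Array[String]): Unit = {
    val expectedBatches = 5
    val latch = new CountDownLatch(expectedBatches)
    val collected = new ArrayBuffer[Int] with SynchronizedBuffer[Int]

    // Hypothetical producer thread: it delivers batches asynchronously and counts
    // the latch down once per batch, mirroring how each sink signals the latch.
    new Thread(new Runnable {
      override def run(): Unit = {
        (1 to expectedBatches).foreach { i =>
          Thread.sleep(50)
          collected += i
          latch.countDown()
        }
      }
    }).start()

    // Step 1: block until every batch has been signalled, bounded by a timeout.
    assert(latch.await(15, TimeUnit.SECONDS), "timed out waiting for batches")

    // Step 2: retry the assertions until the received data has been processed,
    // instead of sleeping for an arbitrary, flakiness-prone amount of time.
    eventually(timeout(10 seconds), interval(100 milliseconds)) {
      assert(collected.size == expectedBatches)
    }
  }
}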
--- .../flume/FlumePollingStreamSuite.scala | 110 ++++++++---------- 1 file changed, 51 insertions(+), 59 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 43c1b865b64a1..93afe50c2134f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,15 +18,18 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} +import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder +import org.scalatest.concurrent.Eventually._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -57,11 +60,11 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging before(beforeFunction()) - ignore("flume polling test") { + test("flume polling test") { testMultipleTimes(testFlumePolling) } - ignore("flume polling test multiple hosts") { + test("flume polling test multiple hosts") { testMultipleTimes(testFlumePollingMultipleHost) } @@ -100,18 +103,8 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging Configurables.configure(sink, context) sink.setChannel(channel) sink.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - ssc.start() - writeAndVerify(Seq(channel), ssc, outputBuffer) + writeAndVerify(Seq(sink), Seq(channel)) assertChannelIsEmpty(channel) sink.stop() channel.stop() @@ -142,10 +135,22 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + try { + writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + } finally { + sink.stop() + sink2.stop() + channel.stop() + channel2.stop() + } + } + def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 5) @@ -155,61 +160,49 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging outputStream.register() ssc.start() - try { - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - 
assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() - } - } - - def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val executor = Executors.newCachedThreadPool() val executorCompletion = new ExecutorCompletionService[Void](executor) - channels.map(channel => { + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { executorCompletion.submit(new TxnSubmitter(channel, clock)) }) + for (i <- 0 until channels.size) { executorCompletion.take() } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < batchCount * channels.size && - System.currentTimeMillis() - startTime < 15000) { - logInfo("output.size = " + outputBuffer.size) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. 
+ eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenedBuffer = outputBuffer.flatten + assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + + String.valueOf(i)).getBytes("utf-8"), + Map[String, String]("test-" + i.toString -> "header")) + var found = false + var j = 0 + while (j < flattenedBuffer.size && !found) { + val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") + if (new String(eventToVerify.getBody, "utf-8") == strToCompare && + eventToVerify.getHeaders.get("test-" + i.toString) + .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { + found = true + counter += 1 + } + j += 1 } - j += 1 } + assert(counter === totalEventsPerChannel * channels.size) } - assert(counter === totalEventsPerChannel * channels.size) + ssc.stop() } def assertChannelIsEmpty(channel: MemoryChannel): Unit = { @@ -234,7 +227,6 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging tx.commit() tx.close() Thread.sleep(500) // Allow some time for the events to reach - clock.advance(batchDuration.milliseconds) } null } From 73bed408fbb47dfc28063afa3898c27fbdec7735 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 13 May 2015 17:07:31 -0700 Subject: [PATCH 026/109] [SPARK-7081] Faster sort-based shuffle path using binary processing cache-aware sort This patch introduces a new shuffle manager that enhances the existing sort-based shuffle with a new cache-friendly sort algorithm that operates directly on binary data. The goals of this patch are to lower memory usage and Java object overheads during shuffle and to speed up sorting. It also lays groundwork for follow-up patches that will enable end-to-end processing of serialized records. The new shuffle manager, `UnsafeShuffleManager`, can be enabled by setting `spark.shuffle.manager=tungsten-sort` in SparkConf. The new shuffle manager uses directly-managed memory to implement several performance optimizations for certain types of shuffles. In cases where the new performance optimizations cannot be applied, the new shuffle manager delegates to SortShuffleManager to handle those shuffles. UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - The shuffle dependency specifies no aggregation or output ordering. - The shuffle serializer supports relocation of serialized values (this is currently supported by KryoSerializer and Spark SQL's custom serializers). - The shuffle produces fewer than 16777216 output partitions. - No individual record is larger than 128 MB when serialized. In addition, extra spill-merging optimizations are automatically applied when the shuffle compression codec supports concatenation of serialized streams. This is currently supported by Spark's LZF serializer. At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. In sort-based shuffle, incoming records are sorted according to their target partition ids, then written to a single map output file. Reducers fetch contiguous regions of this file in order to read their portion of the map output. In cases where the map output data is too large to fit in memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged to produce the final output file. 
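Before the list of specific optimizations below, a minimal, illustrative sketch (not part of this patch) of how an application might opt into the new shuffle path described above: the shuffle manager is selected through spark.shuffle.manager=tungsten-sort and paired with a relocation-friendly serializer such as KryoSerializer. The application name, master URL, and RDD pipeline are arbitrary examples; whether the optimized path is actually taken for a given shuffle still depends on the conditions listed above.

import org.apache.spark.{SparkConf, SparkContext}

object TungstenSortShuffleExample {
  def main(args: Array[String]): Unit = {
    // Select the new shuffle manager and a serializer that supports relocation
    // of serialized values (Kryo), per the description above.
    val conf = new SparkConf()
      .setAppName("tungsten-sort-shuffle-example")
      .setMaster("local[4]")
      .set("spark.shuffle.manager", "tungsten-sort")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

    val sc = new SparkContext(conf)
    try {
      // A plain repartition declares no map-side aggregation and no key ordering,
      // so this shuffle is the kind that is eligible for the optimized path.
      val count = sc.parallelize(1 to 1000000).repartition(64).count()
      println(s"counted $count elements after the shuffle")
    } finally {
      sc.stop()
    }
  }
}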
UnsafeShuffleManager optimizes this process in several ways: - Its sort operates on serialized binary data rather than Java objects, which reduces memory consumption and GC overheads. This optimization requires the record serializer to have certain properties to allow serialized records to be re-ordered without requiring deserialization. See SPARK-4550, where this optimization was first proposed and implemented, for more details. - It uses a specialized cache-efficient sorter (UnsafeShuffleExternalSorter) that sorts arrays of compressed record pointers and partition ids. By using only 8 bytes of space per record in the sorting array, this fits more of the array into cache. - The spill merging procedure operates on blocks of serialized records that belong to the same partition and does not need to deserialize records during the merge. - When the spill compression codec supports concatenation of compressed data, the spill merge simply concatenates the serialized and compressed spill partitions to produce the final output partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used and avoids the need to allocate decompression or copying buffers during the merge. The shuffle read path is unchanged. This patch is similar to [SPARK-4550](http://issues.apache.org/jira/browse/SPARK-4550) / #4450 but uses a slightly different implementation. The `unsafe`-based implementation featured in this patch lays the groundwork for followup patches that will enable sorting to operate on serialized data pages that will be prepared by Spark SQL's new `unsafe` operators (such as the new aggregation operator introduced in #5725). ### Future work There are several tasks that build upon this patch, which will be left to future work: - [SPARK-7271](https://issues.apache.org/jira/browse/SPARK-7271) Redesign / extend the shuffle interfaces to accept binary data as input. The goal here is to let us bypass serialization steps in cases where the sort input is produced by an operator that operates directly on binary data. - Extension / redesign of the `Serializer` API. We can add new methods which allow serializers to determine the size requirements for serializing objects and for serializing objects directly to a specified memory address (similar to how `UnsafeRowConverter` works in Spark SQL). [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5868) Author: Josh Rosen Closes #5868 from JoshRosen/unsafe-sort and squashes the following commits: ef0a86e [Josh Rosen] Fix scalastyle errors 7610f2f [Josh Rosen] Add tests for proper cleanup of shuffle data. d494ffe [Josh Rosen] Fix deserialization of JavaSerializer instances. 52a9981 [Josh Rosen] Fix some bugs in the address packing code. 51812a7 [Josh Rosen] Change shuffle manager sort name to tungsten-sort 4023fa4 [Josh Rosen] Add @Private annotation to some Java classes. de40b9d [Josh Rosen] More comments to try to explain metrics code df07699 [Josh Rosen] Attempt to clarify confusing metrics update code 5e189c6 [Josh Rosen] Track time spend closing / flushing files; split TimeTrackingOutputStream into separate file. d5779c6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort c2ce78e [Josh Rosen] Fix a missed usage of MAX_PARTITION_ID e3b8855 [Josh Rosen] Cleanup in UnsafeShuffleWriter 4a2c785 [Josh Rosen] rename 'sort buffer' to 'pointer array' 6276168 [Josh Rosen] Remove ability to disable spilling in UnsafeShuffleExternalSorter. 
57312c9 [Josh Rosen] Clarify fileBufferSize units 2d4e4f4 [Josh Rosen] Address some minor comments in UnsafeShuffleExternalSorter. fdcac08 [Josh Rosen] Guard against overflow when expanding sort buffer. 85da63f [Josh Rosen] Cleanup in UnsafeShuffleSorterIterator. 0ad34da [Josh Rosen] Fix off-by-one in nextInt() call 56781a1 [Josh Rosen] Rename UnsafeShuffleSorter to UnsafeShuffleInMemorySorter e995d1a [Josh Rosen] Introduce MAX_SHUFFLE_OUTPUT_PARTITIONS. e58a6b4 [Josh Rosen] Add more tests for PackedRecordPointer encoding. 4f0b770 [Josh Rosen] Attempt to implement proper shuffle write metrics. d4e6d89 [Josh Rosen] Update to bit shifting constants 69d5899 [Josh Rosen] Remove some unnecessary override vals 8531286 [Josh Rosen] Add tests that automatically trigger spills. 7c953f9 [Josh Rosen] Add test that covers UnsafeShuffleSortDataFormat.swap(). e1855e5 [Josh Rosen] Fix a handful of misc. IntelliJ inspections 39434f9 [Josh Rosen] Avoid integer multiplication overflow in getMemoryUsage (thanks FindBugs!) 1e3ad52 [Josh Rosen] Delete unused ByteBufferOutputStream class. ea4f85f [Josh Rosen] Roll back an unnecessary change in Spillable. ae538dc [Josh Rosen] Document UnsafeShuffleManager. ec6d626 [Josh Rosen] Add notes on maximum # of supported shuffle partitions. 0d4d199 [Josh Rosen] Bump up shuffle.memoryFraction to make tests pass. b3b1924 [Josh Rosen] Properly implement close() and flush() in DummySerializerInstance. 1ef56c7 [Josh Rosen] Revise compression codec support in merger; test cross product of configurations. b57c17f [Josh Rosen] Disable some overly-verbose logs that rendered DEBUG useless. f780fb1 [Josh Rosen] Add test demonstrating which compression codecs support concatenation. 4a01c45 [Josh Rosen] Remove unnecessary log message 27b18b0 [Josh Rosen] That for inserting records AT the max record size. fcd9a3c [Josh Rosen] Add notes + tests for maximum record / page sizes. 9d1ee7c [Josh Rosen] Fix MiMa excludes for ShuffleWriter change fd4bb9e [Josh Rosen] Use own ByteBufferOutputStream rather than Kryo's 67d25ba [Josh Rosen] Update Exchange operator's copying logic to account for new shuffle manager 8f5061a [Josh Rosen] Strengthen assertion to check partitioning 01afc74 [Josh Rosen] Actually read data in UnsafeShuffleWriterSuite 1929a74 [Josh Rosen] Update to reflect upstream ShuffleBlockManager -> ShuffleBlockResolver rename. e8718dd [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort 9b7ebed [Josh Rosen] More defensive programming RE: cleaning up spill files and memory after errors 7cd013b [Josh Rosen] Begin refactoring to enable proper tests for spilling. 722849b [Josh Rosen] Add workaround for transferTo() bug in merging code; refactor tests. 9883e30 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort b95e642 [Josh Rosen] Refactor and document logic that decides when to spill. 1ce1300 [Josh Rosen] More minor cleanup 5e8cf75 [Josh Rosen] More minor cleanup e67f1ea [Josh Rosen] Remove upper type bound in ShuffleWriter interface. cfe0ec4 [Josh Rosen] Address a number of minor review comments: 8a6fe52 [Josh Rosen] Rename UnsafeShuffleSpillWriter to UnsafeShuffleExternalSorter 11feeb6 [Josh Rosen] Update TODOs related to shuffle write metrics. b674412 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort aaea17b [Josh Rosen] Add comments to UnsafeShuffleSpillWriter. 4f70141 [Josh Rosen] Fix merging; now passes UnsafeShuffleSuite tests. 133c8c9 [Josh Rosen] WIP towards testing UnsafeShuffleWriter. 
f480fb2 [Josh Rosen] WIP in mega-refactoring towards shuffle-specific sort. 57f1ec0 [Josh Rosen] WIP towards packed record pointers for use in optimized shuffle sort. 69232fd [Josh Rosen] Enable compressible address encoding for off-heap mode. 7ee918e [Josh Rosen] Re-order imports in tests 3aeaff7 [Josh Rosen] More refactoring and cleanup; begin cleaning iterator interfaces 3490512 [Josh Rosen] Misc. cleanup f156a8f [Josh Rosen] Hacky metrics integration; refactor some interfaces. 2776aca [Josh Rosen] First passing test for ExternalSorter. 5e100b2 [Josh Rosen] Super-messy WIP on external sort 595923a [Josh Rosen] Remove some unused variables. 8958584 [Josh Rosen] Fix bug in calculating free space in current page. f17fa8f [Josh Rosen] Add missing newline c2fca17 [Josh Rosen] Small refactoring of SerializerPropertiesSuite to enable test re-use: b8a09fe [Josh Rosen] Back out accidental log4j.properties change bfc12d3 [Josh Rosen] Add tests for serializer relocation property. 240864c [Josh Rosen] Remove PrefixComputer and require prefix to be specified as part of insert() 1433b42 [Josh Rosen] Store record length as int instead of long. 026b497 [Josh Rosen] Re-use a buffer in UnsafeShuffleWriter 0748458 [Josh Rosen] Port UnsafeShuffleWriter to Java. 87e721b [Josh Rosen] Renaming and comments d3cc310 [Josh Rosen] Flag that SparkSqlSerializer2 supports relocation e2d96ca [Josh Rosen] Expand serializer API and use new function to help control when new UnsafeShuffle path is used. e267cee [Josh Rosen] Fix compilation of UnsafeSorterSuite 9c6cf58 [Josh Rosen] Refactor to use DiskBlockObjectWriter. 253f13e [Josh Rosen] More cleanup 8e3ec20 [Josh Rosen] Begin code cleanup. 4d2f5e1 [Josh Rosen] WIP 3db12de [Josh Rosen] Minor simplification and sanity checks in UnsafeSorter 767d3ca [Josh Rosen] Fix invalid range in UnsafeSorter. e900152 [Josh Rosen] Add test for empty iterator in UnsafeSorter 57a4ea0 [Josh Rosen] Make initialSize configurable in UnsafeSorter abf7bfe [Josh Rosen] Add basic test case. 
81d52c5 [Josh Rosen] WIP on UnsafeSorter --- core/pom.xml | 10 + .../unsafe/DummySerializerInstance.java | 93 ++++ .../shuffle/unsafe/PackedRecordPointer.java | 92 +++ .../spark/shuffle/unsafe/SpillInfo.java | 37 ++ .../unsafe/UnsafeShuffleExternalSorter.java | 422 ++++++++++++++ .../unsafe/UnsafeShuffleInMemorySorter.java | 124 +++++ .../unsafe/UnsafeShuffleSortDataFormat.java | 67 +++ .../shuffle/unsafe/UnsafeShuffleWriter.java | 438 +++++++++++++++ .../storage/TimeTrackingOutputStream.java | 75 +++ .../scala/org/apache/spark/SparkEnv.scala | 3 +- .../spark/serializer/JavaSerializer.scala | 2 + .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../apache/spark/shuffle/ShuffleWriter.scala | 7 +- .../shuffle/hash/HashShuffleWriter.scala | 2 +- .../shuffle/sort/SortShuffleManager.scala | 2 +- .../shuffle/sort/SortShuffleWriter.scala | 2 +- .../shuffle/unsafe/UnsafeShuffleManager.scala | 205 +++++++ .../spark/storage/BlockObjectWriter.scala | 24 +- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../util/collection/ExternalSorter.scala | 2 +- .../unsafe/PackedRecordPointerSuite.java | 101 ++++ .../UnsafeShuffleInMemorySorterSuite.java | 132 +++++ .../unsafe/UnsafeShuffleWriterSuite.java | 527 ++++++++++++++++++ .../spark/io/CompressionCodecSuite.scala | 44 ++ .../serializer/JavaSerializerSuite.scala | 29 + .../unsafe/UnsafeShuffleManagerSuite.scala | 128 +++++ .../shuffle/unsafe/UnsafeShuffleSuite.scala | 105 ++++ pom.xml | 14 +- project/MimaExcludes.scala | 6 + .../apache/spark/sql/execution/Exchange.scala | 28 +- unsafe/pom.xml | 4 + .../unsafe/memory/TaskMemoryManager.java | 79 ++- .../unsafe/memory/TaskMemoryManagerSuite.java | 23 + 33 files changed, 2767 insertions(+), 64 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java create mode 100644 core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java create mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java create mode 100644 core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java create mode 100644 core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 262a3320db106..bfa49d0d6dc25 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -361,6 +361,16 @@ junit test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + com.novocode junit-interface diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java 
b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java new file mode 100644 index 0000000000000..3f746b886bc9b --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.reflect.ClassTag; + +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. 
+ */ +final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + throw new UnsupportedOperationException(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 0000000000000..4ee6a82c0423e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *
+ * Within the long, the data is laid out as follows: + *
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). + */ +final class PackedRecordPointer { + + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java new file mode 100644 index 0000000000000..7bac0dc0bbeb6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
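As an illustration of the encoding documented above, the following self-contained sketch (not part of the patch; all names in it are illustrative) reproduces the same mask-and-shift arithmetic as packPointer()/getRecordPointer() and round-trips a fabricated TaskMemoryManager-style address, assuming the upper 13 bits of that address hold the page number and the low 27 bits hold the in-page offset:

    public class PackedPointerSketch {
      private static final long LOWER_40 = (1L << 40) - 1;
      private static final long UPPER_24 = ~LOWER_40;
      private static final long LOWER_27 = (1L << 27) - 1;
      private static final long LOWER_51 = (1L << 51) - 1;
      private static final long UPPER_13 = ~LOWER_51;

      static long pack(long recordPointer, int partitionId) {
        // Drop the unused middle bits of the address so page number and offset fit in 40 bits.
        final long pageNumber = (recordPointer & UPPER_13) >>> 24;
        final long compressedAddress = pageNumber | (recordPointer & LOWER_27);
        return (((long) partitionId) << 40) | compressedAddress;
      }

      static int partitionId(long packed) {
        return (int) ((packed & UPPER_24) >>> 40);
      }

      static long recordPointer(long packed) {
        // Move the 13 page-number bits back to the top of the word and re-attach the offset.
        final long pageNumber = (packed << 24) & UPPER_13;
        return pageNumber | (packed & LOWER_27);
      }

      public static void main(String[] args) {
        final long pageNumber = 5L;             // a 13-bit page number (illustrative value)
        final long offsetInPage = 1L << 20;     // must stay below 2^27
        final long address = (pageNumber << 51) | offsetInPage;
        final long packed = pack(address, 42);
        System.out.println(partitionId(packed) == 42);        // true
        System.out.println(recordPointer(packed) == address); // true
      }
    }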
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + */ +final class SpillInfo { + final long[] partitionLengths; + final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java new file mode 100644 index 0000000000000..9e9ed94b7890c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *
+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class UnsafeShuffleExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final int initialSize; + private final int numPartitions; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + private final LinkedList spills = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeShuffleInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.initialSize = initialSize; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + + this.writeMetrics = writeMetrics; + initializeForWriting(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. + * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. 
+ */ + private void writeSortedFile(boolean isLastFile) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + sorter.getSortedIterator(); + + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. + BlockObjectWriter writer; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. 
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = memoryManager.getPage(recordPointer); + final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + PlatformDependent.copyMemory( + recordPage, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. 
+ */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + private long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. 
+ if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + int partitionId) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + freeSpaceInCurrentPage -= 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + freeSpaceInCurrentPage -= lengthInBytes; + sorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + try { + if (sorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java new file mode 100644 index 0000000000000..5bab501da9364 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
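The record format that insertRecord() writes into a data page, a 4-byte length prefix followed by the serialized bytes, is what writeSortedFile() later reads back when streaming each record to disk. The standalone sketch below is illustrative only: a heap ByteBuffer stands in for an Unsafe-managed page, and none of the names come from the patch.

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public class PageLayoutSketch {
      public static void main(String[] args) {
        final ByteBuffer page = ByteBuffer.allocate(1 << 20);   // stand-in for a data page
        final byte[] record = "serialized key/value bytes".getBytes(StandardCharsets.UTF_8);

        // Write path: [int length][record bytes], as insertRecord() does with putInt + copyMemory.
        final int recordStart = page.position();
        page.putInt(record.length);
        page.put(record);

        // Read path: read the length, then copy exactly that many bytes, as writeSortedFile()
        // does through its transfer buffer.
        page.position(recordStart);
        final byte[] out = new byte[page.getInt()];
        page.get(out);
        System.out.println(new String(out, StandardCharsets.UTF_8));
      }
    }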
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +final class UnsafeShuffleInMemorySorter { + + private final Sorter sorter; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); + + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; + + /** + * The position in the pointer array where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeShuffleInMemorySorter(int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize]; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 1 < pointerArray.length; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); + } else { + expandPointerArray(); + } + } + pointerArray[pointerArrayInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + pointerArrayInsertPosition++; + } + + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + + private final long[] pointerArray; + private final int numRecords; + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; + + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + } + + public boolean hasNext() { + return position < numRecords; + } + + public void loadNext() { + packedRecordPointer.set(pointerArray[position]); + position++; + } + } + + /** + * Return an iterator over record pointers in sorted order. 
+ */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 0000000000000..a66d74ee44782 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.set(data[pos]); + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 0000000000000..ad7eb04afcd8c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
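Taken together, the two sorter classes above order an array of packed longs purely by the partition id stored in the upper bits, so that all records belonging to one reduce partition become contiguous. The toy program below uses boxed Longs and a JDK comparator purely for illustration; the real code sorts a primitive long[] with Spark's Sorter and UnsafeShuffleSortDataFormat, and the class name here is invented.

    import java.util.Arrays;
    import java.util.Comparator;

    public class PartitionOrderSketch {
      static int partitionId(long packed) {
        return (int) (packed >>> 40);   // the upper 24 bits carry the partition id
      }

      public static void main(String[] args) {
        Long[] packed = {
            (3L << 40) | 100, (1L << 40) | 200, (3L << 40) | 50, (0L << 40) | 10 };
        Arrays.sort(packed, Comparator.comparingInt((Long p) -> partitionId(p)));
        for (long p : packed) {
          System.out.println("partition " + partitionId(p) + " -> pointer " + (p & ((1L << 40) - 1)));
        }
        // Prints partition 0, then 1, then the two partition-3 entries; only the
        // partition ordering matters, not the order of pointers within a partition.
      }
    }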
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Iterator; +import javax.annotation.Nullable; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConversions; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.annotation.Private; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +@Private +public class UnsafeShuffleWriter extends ShuffleWriter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + + private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + + private MapStatus mapStatus = null; + private UnsafeShuffleExternalSorter sorter = null; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. 
+ */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf sparkConf) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConversions.asScalaIterator(records)); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (!success) { + sorter.cleanupAfterError(); + } + } + } + + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + INITIAL_SORT_BUFFER_SIZE, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! 
spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. 
+ if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. 
+ */ + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } + } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. 
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], + bytesToTransfer, + mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; + } + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + Closeables.close(mergedFileOutputChannel, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java new file mode 100644 index 0000000000000..dc2aa30466cc6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
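The core decision in mergeSpills() above is which of the three merge paths to take. The helper below is a condensed, illustrative restatement of that branch logic; the class and method names are not from the patch, and the boolean inputs correspond to spark.shuffle.compress, whether the codec supports concatenation of compressed streams (e.g. LZF), spark.shuffle.unsafe.fastMergeEnabled, and spark.file.transferTo.

    public class MergeStrategySketch {
      static String choose(boolean compressionEnabled, boolean codecSupportsConcatenation,
                           boolean fastMergeEnabled, boolean transferToEnabled) {
        final boolean fastMergeIsSupported = !compressionEnabled || codecSupportsConcatenation;
        if (fastMergeEnabled && fastMergeIsSupported) {
          // Fast merges copy raw bytes without decompressing or deserializing them.
          return transferToEnabled ? "transferTo-based fast merge" : "fileStream-based fast merge";
        }
        // Slow merge: each partition's stream is decompressed and recompressed while copying.
        return "slow merge";
      }

      public static void main(String[] args) {
        System.out.println(choose(true, true, true, true));   // concatenable codec -> transferTo fast merge
        System.out.println(choose(true, false, true, true));  // non-concatenable codec -> slow merge
      }
    }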
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; + +/** + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. + */ +@Private +public final class TimeTrackingOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final OutputStream outputStream; + + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + final long startTime = System.nanoTime(); + outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void close() throws IOException { + final long startTime = System.nanoTime(); + outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0c4d28f786edd..a5d831c7e68ad 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -313,7 +313,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index dfbde7c8a1b0d..698d1384d580d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ 
b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6ad427bcac7f9..6c3b3080d2605 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index f6e6fe5defe09..4cc4ef5f1886e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. 
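 * Implementations include HashShuffleWriter, SortShuffleWriter and the new Java-based
 * UnsafeShuffleWriter introduced by this patch.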
*/ -private[spark] trait ShuffleWriter[K, V] { +private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit + @throws[IOException] + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 897f0a5dc5bcc..eb87cee15903c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V]( writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 15842941daaab..d7fab351ca3b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager true } - override def shuffleBlockResolver: IndexShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { indexShuffleBlockResolver } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index add2656294ca2..c9dd6bfc4c219 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") sorter = new ExternalSorter[K, V, C]( diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 0000000000000..f2bfef376d3ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.keyOrdering.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. 
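+ * (Concretely this means the LZF compression codec: the CompressionCodecSuite tests added in this
+ * patch show that LZF streams can be concatenated, while LZ4 and Snappy streams cannot.)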
+ * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") + } + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. 
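+   * Handles created by the unsafe path are given an UnsafeShuffleWriter; all other handles are
+   * delegated to the wrapped SortShuffleManager's writer.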
*/ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + } + + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 8bc4e205bc3c6..a33f22ef52687 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter( extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - override def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - callWithTiming(out.write(b, off, len)) - } - override def close(): Unit = out.close() - override def flush(): Unit = out.flush() - } /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter( throw new IllegalStateException("Writer already closed. 
Cannot be reopened.") } fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) + ts = new TimeTrackingOutputStream(writeMetrics, fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializerInstance.serializeStream(bs) @@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - callWithTiming { - fos.getFD.sync() - } + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter( reportedPosition = pos } - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - // For testing private[spark] override def flush() { objOut.flush() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index b850973145077..df2d6ad3b41a4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7d5cf7b61e56a..3b9d14f9372b6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java new file mode 100644 index 0000000000000..db9e82759090a --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; + +public class PackedRecordPointerSuite { + + @Test + public void heap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void offHeap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void maximumPartitionIdCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID)); + assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId()); + } + + @Test + public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + try { + // Pointers greater than the maximum partition ID will overflow or trigger an assertion error + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); + assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId()); + } catch (AssertionError e ) { + // pass + } + } + + @Test + public void maximumOffsetInPageCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = 
TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(address, packedPointer.getRecordPointer()); + } + + @Test + public void offsetsPastMaxOffsetInPageWillOverflow() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(0, packedPointer.getRecordPointer()); + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java new file mode 100644 index 0000000000000..8fa72597db24d --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Arrays; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeShuffleInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { + final byte[] strBytes = new byte[strLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testBasicSorting() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final HashPartitioner hashPartitioner = new HashPartitioner(4); + + // Write the records into the data page and store pointers into the sorter + long position = 
dataPage.getBaseOffset(); + for (String str : dataToSort) { + final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); + } + + // Sort the records + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int prevPartitionId = -1; + Arrays.sort(dataToSort); + for (int i = 0; i < dataToSort.length; i++) { + Assert.assertTrue(iter.hasNext()); + iter.loadNext(); + final int partitionId = iter.packedRecordPointer.getPartitionId(); + Assert.assertTrue(partitionId >= 0 && partitionId <= 3); + Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, + partitionId >= prevPartitionId); + final long recordAddress = iter.packedRecordPointer.getRecordPointer(); + final int recordLength = PlatformDependent.UNSAFE.getInt( + memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); + final String str = getStringFromDataPage( + memoryManager.getPage(recordAddress), + memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length + recordLength); + Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1); + } + Assert.assertFalse(iter.hasNext()); + } + + @Test + public void testSortingManyNumbers() throws Exception { + UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + int[] numbersToSort = new int[128000]; + Random random = new Random(16); + for (int i = 0; i < numbersToSort.length; i++) { + numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); + sorter.insertRecord(0, numbersToSort[i]); + } + Arrays.sort(numbersToSort); + int[] sorterResult = new int[numbersToSort.length]; + UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int j = 0; + while (iter.hasNext()) { + iter.loadNext(); + sorterResult[j] = iter.packedRecordPointer.getPartitionId(); + j += 1; + } + Assert.assertArrayEquals(numbersToSort, sorterResult); + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java new file mode 100644 index 0000000000000..730d265c87f88 --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.ByteBuffer; +import java.util.*; + +import scala.*; +import scala.collection.Iterator; +import scala.reflect.ClassTag; +import scala.runtime.AbstractFunction1; + +import com.google.common.collect.HashMultiset; +import com.google.common.io.ByteStreams; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.*; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZ4CompressionCodec; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.io.SnappyCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.serializer.*; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeShuffleWriterSuite { + + static final int NUM_PARTITITONS = 4; + final TaskMemoryManager taskMemoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); + File mergedOutputFile; + File tempDir; + long[] partitionSizesInMergedFile; + final LinkedList spillFilesCreated = new LinkedList(); + SparkConf conf; + final Serializer serializer = new KryoSerializer(new SparkConf()); + TaskMetrics taskMetrics; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + + private final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); + } else { + return stream; + } + } + } + + @After + public void tearDown() { + Utils.deleteRecursively(tempDir); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } + } + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + tempDir = Utils.createTempDir("test", "test"); + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); + partitionSizesInMergedFile 
= null; + spillFilesCreated.clear(); + conf = new SparkConf(); + taskMetrics = new TaskMetrics(); + + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( + new Answer() { + @Override + public InputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + InputStream is = (InputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); + } else { + return is; + } + } + } + ); + + when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( + new Answer() { + @Override + public OutputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + OutputStream os = (OutputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); + } else { + return os; + } + } + } + ); + + when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + return null; + } + }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer>() { + @Override + public Tuple2 answer( + InvocationOnMock invocationOnMock) throws Throwable { + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + + when(taskContext.taskMetrics()).thenReturn(taskMetrics); + + when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); + when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + } + + private UnsafeShuffleWriter createWriter( + boolean transferToEnabled) throws IOException { + conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); + return new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, shuffleDep), + 0, // map id + taskContext, + conf + ); + } + + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + + private List> readRecordsFromFile() throws IOException { + final ArrayList> recordsList = new ArrayList>(); + long startOffset = 0; + for (int i = 0; 
i < NUM_PARTITITONS; i++) { + final long partitionSize = partitionSizesInMergedFile[i]; + if (partitionSize > 0) { + InputStream in = new FileInputStream(mergedOutputFile); + ByteStreams.skipFully(in, startOffset); + in = new LimitedInputStream(in, partitionSize); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } + recordsStream.close(); + startOffset += partitionSize; + } + } + return recordsList; + } + + @Test(expected=IllegalStateException.class) + public void mustCallWriteBeforeSuccessfulStop() throws IOException { + createWriter(false).stop(true); + } + + @Test + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { + createWriter(false).stop(false); + } + + @Test + public void writeEmptyIterator() throws Exception { + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(Collections.>emptyIterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + } + + @Test + public void writeWithoutSpilling() throws Exception { + // In this example, each partition should have exactly one record: + final ArrayList> dataToWrite = + new ArrayList>(); + for (int i = 0; i < NUM_PARTITITONS; i++) { + dataToWrite.add(new Tuple2(i, i)); + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(dataToWrite.iterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + // All partitions should be the same size: + assertEquals(partitionSizesInMergedFile[0], size); + sumOfPartitionSizes += size; + } + assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + private void testMergingSpills( + boolean transferToEnabled, + String compressionCodecName) throws IOException { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } + final UnsafeShuffleWriter writer = createWriter(transferToEnabled); + final ArrayList> dataToWrite = + new ArrayList>(); + for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { + dataToWrite.add(new Tuple2(i, i)); + } + 
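+    // Insert the first four records, force a spill, insert the remaining two, and then close;
+    // the final output must therefore be merged from two spill files (asserted below).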
writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); + writer.closeAndWriteOutput(); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndNoCompression() throws Exception { + testMergingSpills(true, null); + } + + @Test + public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { + testMergingSpills(false, null); + } + + @Test + public void writeEnoughDataToTriggerSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to allocate new data page + .then(returnsFirstArg()); // Grant new sort buffer and data page. 
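+    // Because the request for a fresh data page is denied above, the writer has to spill its
+    // in-memory data to disk before it can continue inserting records.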
+ final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; + for (int i = 0; i < 128 + 1; i++) { + dataToWrite.add(new Tuple2(i, bigByteArray)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to grow sort buffer + .then(returnsFirstArg()); // Grant new sort buffer and data page. + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + dataToWrite.add(new Tuple2(i, i)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = + new ArrayList>(); + final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { + // Use a custom serializer so that we have exact control over the size of serialized data. 
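+    // The anonymous serializer below writes each byte[] it receives directly to the output
+    // stream, so a record's serialized size is exactly the length of its key and value arrays.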
+ final Serializer byteArraySerializer = new Serializer() { + @Override + public SerializerInstance newInstance() { + return new SerializerInstance() { + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + byte[] bytes = (byte[]) t; + try { + s.write(bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public void close() { } + }; + } + public ByteBuffer serialize(T t, ClassTag ev1) { return null; } + public DeserializationStream deserializeStream(InputStream s) { return null; } + public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } + public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } + }; + } + }; + when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); + final UnsafeShuffleWriter writer = createWriter(false); + // Insert a record and force a spill so that there's something to clean up: + writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); + writer.forceSorterToSpill(); + // We should be able to write a record that's right _at_ the max record size + final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + new Random(42).nextBytes(atMaxRecordSize); + writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); + writer.forceSorterToSpill(); + // Inserting a record that's larger than the max record size should fail: + final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + new Random(42).nextBytes(exceedsMaxRecordSize); + Product2 hugeRecord = + new Tuple2(new byte[0], exceedsMaxRecordSize); + try { + // Here, we write through the public `write()` interface instead of the test-only + // `insertRecordIntoSorter` interface: + writer.write(Collections.singletonList(hugeRecord).iterator()); + fail("Expected exception to be thrown"); + } catch (IOException e) { + // Pass + } + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + final UnsafeShuffleWriter writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 8c6035fb367fe..cf6a143537889 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import com.google.common.io.ByteStreams import org.scalatest.FunSuite import org.apache.spark.SparkConf @@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) + assert(codec.getClass === classOf[LZ4CompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, 
classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) @@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lzf supports concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) + assert(codec.getClass === classOf[LZFCompressionCodec]) + testConcatenationOfSerializedStreams(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) @@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("snappy does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("bad compression codec") { intercept[IllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") } } + + private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = { + val bytes1: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (0 to 64).foreach(out.write) + out.close() + baos.toByteArray + } + val bytes2: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (65 to 127).foreach(out.write) + out.close() + baos.toByteArray + } + val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2)) + val decompressed: Array[Byte] = new Array[Byte](128) + ByteStreams.readFully(concatenatedBytes, decompressed) + assert(decompressed.toSeq === (0 to 127)) + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala new file mode 100644 index 0000000000000..ed4d8ce632e16 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + +class JavaSerializerSuite extends FunSuite { + test("JavaSerializer instances are serializable") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(serializer)) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala new file mode 100644 index 0000000000000..49a04a2a45280 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark._ +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} + +/** + * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are + * performed in other suites. 
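+ * The shuffle dependencies used here are Mockito mocks that stub only the fields consulted by
+ * canUseUnsafeShuffle; any other method call fails fast via RuntimeExceptionAnswer.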
+ */ +class UnsafeShuffleManagerSuite extends FunSuite with Matchers { + + import UnsafeShuffleManager.canUseUnsafeShuffle + + private class RuntimeExceptionAnswer extends Answer[Object] { + override def answer(invocation: InvocationOnMock): Object = { + throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName) + } + } + + private def shuffleDep( + partitioner: Partitioner, + serializer: Option[Serializer], + keyOrdering: Option[Ordering[Any]], + aggregator: Option[Aggregator[Any, Any, Any]], + mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = { + val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer()) + doReturn(0).when(dep).shuffleId + doReturn(partitioner).when(dep).partitioner + doReturn(serializer).when(dep).serializer + doReturn(keyOrdering).when(dep).keyOrdering + doReturn(aggregator).when(dep).aggregator + doReturn(mapSideCombine).when(dep).mapSideCombine + dep + } + + test("supported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) + when(rangePartitioner.numPartitions).thenReturn(2) + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = rangePartitioner, + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + } + + test("unsupported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + val java = Some(new JavaSerializer(new SparkConf())) + + // We only support serializers that support object relocation + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = java, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles with more than 16 million output partitions + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = None, + mapSideCombine = false + ))) + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = true + ))) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala new file mode 100644 index 0000000000000..6351539e91e97 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + +class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. + + override def beforeAll() { + conf.set("spark.shuffle.manager", "tungsten-sort") + // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort + // shuffle records. + conf.set("spark.shuffle.memoryFraction", "0.5") + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles: Set[File] = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, 
new HashPartitioner(4)) + .setSerializer(new JavaSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles: Set[File] = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } +} diff --git a/pom.xml b/pom.xml index cf9279ea5a2a6..564a443466e5a 100644 --- a/pom.xml +++ b/pom.xml @@ -669,7 +669,7 @@ org.mockito mockito-all - 1.9.0 + 1.9.5 test @@ -684,6 +684,18 @@ 4.10 test + + org.hamcrest + hamcrest-core + 1.3 + test + + + org.hamcrest + hamcrest-library + 1.3 + test + com.novocode junit-interface diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fba7290dcb0b5..487062a31f77f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -131,6 +131,12 @@ object MimaExcludes { // SPARK-7530 Added StreamingContext.getState() ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") ) case v if v.startsWith("1.3") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index c3d2c7019a54a..3e46596ecf6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.MutablePair object Exchange { @@ -85,7 +86,9 @@ case class Exchange( // corner-cases where a partitioner constructed with `numPartitions` partitions may output // fewer partitions (like RangePartitioner, for example). 
val conf = child.sqlContext.sparkContext.conf - val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || + shuffleManager.isInstanceOf[UnsafeShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (newOrdering.nonEmpty) { @@ -93,11 +96,11 @@ case class Exchange( // which requires a defensive copy. true } else if (sortBasedShuffleOn) { - // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory. - // However, there are two special cases where we can avoid the copy, described below: - if (partitioner.numPartitions <= bypassMergeThreshold) { - // If the number of output partitions is sufficiently small, then Spark will fall back to - // the old hash-based shuffle write path which doesn't buffer deserialized records. + val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { @@ -105,9 +108,14 @@ case class Exchange( // them. This optimization is guarded by a feature-flag and is only applied in cases where // shuffle dependency does not specify an ordering and the record serializer has certain // properties. If this optimization is enabled, we can safely avoid the copy. + // + // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // None of the special cases held, so we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code + // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls + // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In + // both cases, we must copy. true } } else { diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 5b0733206b2bc..9e151fc7a9141 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -42,6 +42,10 @@ com.google.code.findbugs jsr305 + + com.google.guava + guava + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 9224988e6ad69..2906ac8abad1a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import java.util.*; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,10 +48,18 @@ public final class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - /** - * The number of entries in the page table. - */ - private static final int PAGE_TABLE_SIZE = 1 << 13; + /** The number of bits used to address the page table. 
*/ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** Maximum supported data page size */ + private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -101,11 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (logger.isTraceEnabled()) { - logger.trace("Allocating {} byte page", size); - } - if (size >= (1L << 51)) { - throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + if (size > MAXIMUM_PAGE_SIZE) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); } final int pageNumber; @@ -120,8 +127,8 @@ public MemoryBlock allocatePage(long size) { final MemoryBlock page = executorMemoryManager.allocate(size); page.pageNumber = pageNumber; pageTable[pageNumber] = page; - if (logger.isDebugEnabled()) { - logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); } return page; } @@ -130,9 +137,6 @@ public MemoryBlock allocatePage(long size) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. */ public void freePage(MemoryBlock page) { - if (logger.isTraceEnabled()) { - logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); - } assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; executorMemoryManager.free(page); @@ -140,8 +144,8 @@ public void freePage(MemoryBlock page) { allocatedPages.clear(page.pageNumber); } pageTable[page.pageNumber] = null; - if (logger.isDebugEnabled()) { - logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } } @@ -173,14 +177,36 @@ public void free(MemoryBlock memory) { /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (inHeap) { - assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); - } else { - return offsetInPage; + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. 
Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } /** @@ -189,7 +215,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final Object page = pageTable[pageNumber].getBaseObject(); assert (page != null); @@ -204,10 +230,15 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { - return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + return offsetInPage; } else { - return pagePlusOffsetAddress; + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + return pageTable[pageNumber].getBaseOffset() + offsetInPage; } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 932882f1ca248..06fb081183659 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -38,4 +38,27 @@ public void leakedPageMemoryIsDetected() { Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock dataPage = manager.allocatePage(256); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. 
This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + } From 59aaa1dad6bee06e38ee5c03bdf82354242286ee Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Wed, 13 May 2015 17:24:04 -0700 Subject: [PATCH 027/109] [SPARK-7601] [SQL] Support Insert into JDBC Datasource Supported InsertableRelation for JDBC Datasource JDBCRelation. Example usage: sqlContext.sql( s""" |CREATE TEMPORARY TABLE testram1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'testram1', user 'xx', password 'xx', driver 'com.h2.Driver') """.stripMargin.replaceAll("\n", " ")) sqlContext.sql("insert into table testram1 select * from testsrc") sqlContext.sql("insert overwrite table testram1 select * from testsrc") Author: Venkata Ramana Gollamudi Closes #6121 from gvramana/JDBCDatasource_insert and squashes the following commits: f3fb5f1 [Venkata Ramana Gollamudi] Support for JDBC Datasource InsertableRelation --- .../apache/spark/sql/jdbc/JDBCRelation.scala | 8 +++- .../spark/sql/jdbc/JDBCWriteSuite.scala | 37 ++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index d6b3fb3291a2e..93e82549f213b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ @@ -129,7 +130,8 @@ private[sql] case class JDBCRelation( parts: Array[Partition], properties: Properties = new Properties())(@transient val sqlContext: SQLContext) extends BaseRelation - with PrunedFilteredScan { + with PrunedFilteredScan + with InsertableRelation { override val needConversion: Boolean = false @@ -148,4 +150,8 @@ private[sql] case class JDBCRelation( filters, parts) } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + data.insertIntoJDBC(url, table, overwrite, properties) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index f3ce8e66460e5..0800eded443de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -43,6 +43,29 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() + 
conn1.prepareStatement("drop table if exists test.people").executeUpdate() + conn1.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() + conn1.prepareStatement("drop table if exists test.people1").executeUpdate() + conn1.prepareStatement( + "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.commit() + + TestSQLContext.sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + TestSQLContext.sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE1 + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) } after { @@ -114,5 +137,17 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) } } - + + test("INSERT to JDBC Datasource") { + TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } + + test("INSERT to JDBC Datasource with overwrite") { + TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } } From bce00dac403d3be2be59218b7b93a56c34c68f1a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 13 May 2015 17:33:15 -0700 Subject: [PATCH 028/109] [SPARK-6752] [STREAMING] [REVISED] Allow StreamingContext to be recreated from checkpoint and existing SparkContext This is a revision of the earlier version (see #5773) that passed the active SparkContext explicitly through a new set of Java and Scala API. The drawbacks are. * Hard to implement in python. * New API introduced. This is even more confusing since we are introducing getActiveOrCreate in SPARK-7553 Furthermore, there is now a direct way get an existing active SparkContext or create a new on - SparkContext.getOrCreate(conf). Its better to use this to get the SparkContext rather than have a new API to explicitly pass the context. So in this PR I have * Removed the new versions of StreamingContext.getOrCreate() which took SparkContext * Added the ability to pick up existing SparkContext when the StreamingContext tries to create a SparkContext. 
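For illustration, a minimal sketch of the resulting usage pattern (the app name, checkpoint path, and batch duration below are placeholders, not part of this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf().setMaster("local[2]").setAppName("checkpoint-recovery-sketch")
    // Reuse the active SparkContext if one already exists, otherwise create one.
    val sc = SparkContext.getOrCreate(conf)

    def createContext(): StreamingContext = {
      // Called only when no checkpoint data exists at the given path.
      new StreamingContext(sc, Seconds(1))
    }

    // Recovers from the checkpoint if present; the recovered StreamingContext now picks up
    // the already-active SparkContext via SparkContext.getOrCreate instead of constructing
    // a new one, so no SparkContext has to be passed in explicitly.
    val ssc = StreamingContext.getOrCreate("/path/to/checkpoint", createContext _)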
Author: Tathagata Das Closes #6096 from tdas/SPARK-6752 and squashes the following commits: 53f4b2d [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-6752 f024b77 [Tathagata Das] Removed extra API and used SparkContext.getOrCreate --- .../spark/streaming/StreamingContext.scala | 49 +------------ .../api/java/JavaStreamingContext.scala | 45 ------------ .../apache/spark/streaming/JavaAPISuite.java | 25 +------ .../streaming/StreamingContextSuite.scala | 70 ++----------------- 4 files changed, 9 insertions(+), 180 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 407cab45ed4c6..1d2ecdd341813 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -134,7 +134,7 @@ class StreamingContext private[streaming] ( if (sc_ != null) { sc_ } else if (isCheckpointPresent) { - new SparkContext(cp_.createSparkConf()) + SparkContext.getOrCreate(cp_.createSparkConf()) } else { throw new SparkException("Cannot create StreamingContext without a SparkContext") } @@ -750,53 +750,6 @@ object StreamingContext extends Logging { checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext - ): StreamingContext = { - getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new StreamingContext if there is an - * error in reading checkpoint data. By default, an exception will be - * thrown on error. 
- */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext, - createOnError: Boolean - ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) - checkpointOption.map(new StreamingContext(sparkContext, _, null)) - .getOrElse(creatingFunc(sparkContext)) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index d8fbed2c50644..b639b94d5ca47 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -804,51 +804,6 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc, createOnError) - new JavaStreamingContext(ssc) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 2e00b980b9e44..1077b1b2cb7e3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1766,29 +1766,10 @@ public JavaStreamingContext call() { Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - newContextCreated.set(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + JavaSparkContext sc = new JavaSparkContext(conf); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5f93332896de1..4b12affbb0ddd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -419,76 +419,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) - } - - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + 
assert(ssc.conf.get("someKey") === "someValue", "checkpointed config not recovered") } - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path + // getOrCreate should recover StreamingContext with existing SparkContext testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + sc = new SparkContext(conf) + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") + assert(!ssc.conf.contains("someKey"), "checkpointed config unexpectedly recovered") } } From 32e27df412706b30daf41f9d46c5572bb9a41bdb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 13 May 2015 17:55:06 -0700 Subject: [PATCH 029/109] [HOTFIX] Bug in merge script --- dev/merge_spark_pr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index f952c9d0b15e2..1c126f50bf095 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -266,10 +266,9 @@ def get_version_json(version_str): resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] - custom_fields = {'resolution': {'id': resolution.raw['id']}} asf_jira.transition_issue( jira_id, resolve["id"], fixVersions = jira_fix_versions, - comment = comment, fields = custom_fields) + comment = comment, resolution = {'id': resolution.raw['id']}) print "Successfully resolved %s with fixVersions=%s!" 
% (jira_id, fix_versions) From 728af88cf6be4c25a732ab7e4fe66c1ed0041164 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 13 May 2015 17:58:29 -0700 Subject: [PATCH 030/109] [HOTFIX] Use 'new Job' in fsBasedParquet.scala Same issue as #6095 cc liancheng Author: zsxwing Closes #6136 from zsxwing/hotfix and squashes the following commits: 4beea54 [zsxwing] Use 'new Job' in fsBasedParquet.scala --- .../scala/org/apache/spark/sql/parquet/fsBasedParquet.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala index d810d6a028c58..c83a9c35dbddf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala @@ -231,7 +231,7 @@ private[sql] class FSBasedParquetRelation( filters: Array[Filter], inputPaths: Array[String]): RDD[Row] = { - val job = Job.getInstance(SparkHadoopUtil.get.conf) + val job = new Job(SparkHadoopUtil.get.conf) val conf = ContextUtil.getConfiguration(job) ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) From 3113da9c7067bbf90639866ae9d946f02cc484ff Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 May 2015 21:04:13 -0700 Subject: [PATCH 031/109] [HOT FIX #6125] Do not wait for all stages to start rendering zsxwing Author: Andrew Or Closes #6138 from andrewor14/dag-viz-clean-properly and squashes the following commits: 19d4e98 [Andrew Or] Add synchronize 02542d6 [Andrew Or] Rename overloaded variable d11bee1 [Andrew Or] Don't wait until all stages have started before rendering --- .../ui/scope/RDDOperationGraphListener.scala | 34 ++++++++++--------- .../RDDOperationGraphListenerSuite.scala | 1 - 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala index 3b77a1e12cc45..aa9c25cb5c8c6 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala @@ -41,11 +41,11 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) /** Return the graph metadata for the given stage, or None if no such information exists. */ - def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { - val stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty } - val graphs = stageIds.flatMap { sid => stageIdToGraph.get(sid) } + def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized { + val _stageIds = jobIdToStageIds.get(jobId).getOrElse { Seq.empty } + val graphs = _stageIds.flatMap { sid => stageIdToGraph.get(sid) } // If the metadata for some stages have been removed, do not bother rendering this job - if (stageIds.size != graphs.size) { + if (_stageIds.size != graphs.size) { Seq.empty } else { graphs @@ -53,16 +53,29 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen } /** Return the graph metadata for the given stage, or None if no such information exists. 
*/ - def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = { + def getOperationGraphForStage(stageId: Int): Option[RDDOperationGraph] = synchronized { stageIdToGraph.get(stageId) } /** On job start, construct a RDDOperationGraph for each stage in the job for display later. */ override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobId = jobStart.jobId + val stageInfos = jobStart.stageInfos + jobIds += jobId jobIdToStageIds(jobId) = jobStart.stageInfos.map(_.stageId).sorted + stageInfos.foreach { stageInfo => + stageIds += stageInfo.stageId + stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo) + // Remove state for old stages + if (stageIds.size >= retainedStages) { + val toRemove = math.max(retainedStages / 10, 1) + stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) } + stageIds.trimStart(toRemove) + } + } + // Remove state for old jobs if (jobIds.size >= retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) @@ -71,15 +84,4 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen } } - /** Remove graph metadata for old stages */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - val stageInfo = stageSubmitted.stageInfo - stageIds += stageInfo.stageId - stageIdToGraph(stageInfo.stageId) = RDDOperationGraph.makeOperationGraph(stageInfo) - if (stageIds.size >= retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stageIds.take(toRemove).foreach { id => stageIdToGraph.remove(id) } - stageIds.trimStart(toRemove) - } - } } diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index 619b38ac02676..c659fc1e8b9a9 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -31,7 +31,6 @@ class RDDOperationGraphListenerSuite extends FunSuite { assert(numStages > 0, "I will not run a job with 0 stages for you.") val stageInfos = (0 until numStages).map { _ => val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d") - listener.onStageSubmitted(new SparkListenerStageSubmitted(stageInfo)) stageIdCounter += 1 stageInfo } From d5f18de1657bfabf5493011e0b2c7ec29c02c64c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 May 2015 21:27:17 -0700 Subject: [PATCH 032/109] [SPARK-7612] [MLLIB] update NB training to use mllib's BLAS This is similar to the changes to k-means, which gives us better control on the performance. 
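As a minimal sketch of the new accumulation path (only meaningful inside Spark's own mllib sources, since org.apache.spark.mllib.linalg.BLAS is package-private; the vector values are made up for illustration):

    import org.apache.spark.mllib.linalg.{BLAS, Vectors}

    val sum = Vectors.dense(0.0, 0.0, 0.0).toDense   // running per-label term-frequency sum
    val sample = Vectors.dense(1.0, 0.0, 2.0)        // one sample's feature vector
    // In-place sum += sample, replacing the previous c._2 += v.toBreeze conversion to Breeze.
    BLAS.axpy(1.0, sample, sum)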
dbtsai Author: Xiangrui Meng Closes #6128 from mengxr/SPARK-7612 and squashes the following commits: b5c24c5 [Xiangrui Meng] merge master a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS --- .../mllib/classification/NaiveBayes.scala | 43 +++++++++---------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index b381dc2cb0140..af24ab616663b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.numerics.{exp => brzExp, log => brzLog} - import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] ( val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * brzData) ) + labels(brzArgmax(brzPi + brzTheta * brzData)) case "Bernoulli" => if (!brzData.forall(v => v == 0.0 || v == 1.0)) { throw new SparkException( s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") } - labels (brzArgmax (brzPi + + labels(brzArgmax(brzPi + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. @@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. 
checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -288,10 +286,8 @@ class NaiveBayes private ( def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") @@ -300,10 +296,8 @@ class NaiveBayes private ( val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(v => v == 0.0 || v == 1.0)) { throw new SparkException( @@ -314,21 +308,24 @@ class NaiveBayes private ( // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( createCombiner = (v: Vector) => { if (modelType == "Bernoulli") { requireZeroOneBernoulliValues(v) } else { requireNonnegativeValues(v) } - (1L, v.toBreeze.toDenseVector) + (1L, v.copy.toDense) }, - mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + mergeValue = (c: (Long, DenseVector), v: Vector) => { requireNonnegativeValues(v) - (c._1 + 1L, c._2 += v.toBreeze) + BLAS.axpy(1.0, v, c._2) + (c._1 + 1L, c._2) }, - mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => - (c1._1 + c2._1, c1._2 += c2._2) + mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { + BLAS.axpy(1.0, c2._2, c1._2) + (c1._1 + c2._1, c1._2) + } ).collect() val numLabels = aggregated.length @@ -348,7 +345,7 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda) case "Bernoulli" => math.log(n + 2.0 * lambda) case _ => // This should never happen. 
From d3db2fd66752e80865e9c7a75d8e8d945121697e Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 13 May 2015 22:23:21 -0700 Subject: [PATCH 033/109] [SPARK-7620] [ML] [MLLIB] Removed calling size, length in while condition to avoid extra JVM call Author: DB Tsai Closes #6137 from dbtsai/clean and squashes the following commits: 185816d [DB Tsai] fix compilication issue f418d08 [DB Tsai] first commit --- .../classification/LogisticRegression.scala | 9 ++-- .../apache/spark/ml/feature/Bucketizer.scala | 3 +- .../spark/ml/feature/VectorIndexer.scala | 6 ++- .../ml/regression/LinearRegression.scala | 6 ++- .../spark/mllib/feature/ChiSqSelector.scala | 3 +- .../spark/mllib/optimization/Updater.scala | 3 +- .../mllib/regression/IsotonicRegression.scala | 8 ++-- .../stat/MultivariateOnlineSummarizer.scala | 47 +++++++++++-------- .../spark/mllib/stat/test/ChiSqTest.scala | 6 ++- .../spark/mllib/tree/impurity/Impurity.scala | 14 +++--- .../mllib/util/LinearDataGenerator.scala | 3 +- .../spark/ml/feature/BucketizerSuite.scala | 6 ++- .../LogisticRegressionSuite.scala | 3 +- 13 files changed, 73 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 93ba91167bfad..2b103626873a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -258,7 +258,8 @@ class LogisticRegressionModel private[ml] ( rawPrediction match { case dv: DenseVector => var i = 0 - while (i < dv.size) { + val size = dv.size + while (i < size) { dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i))) i += 1 } @@ -357,7 +358,8 @@ private[classification] class MultiClassSummarizer extends Serializable { def histogram: Array[Long] = { val result = Array.ofDim[Long](numClasses) var i = 0 - while (i < result.length) { + val len = result.length + while (i < len) { result(i) = distinctMap.getOrElse(i, 0L) i += 1 } @@ -480,7 +482,8 @@ private class LogisticAggregator( var i = 0 val localThisGradientSumArray = this.gradientSumArray val localOtherGradientSumArray = other.gradientSumArray - while (i < localThisGradientSumArray.length) { + val len = localThisGradientSumArray.length + while (i < len) { localThisGradientSumArray(i) += localOtherGradientSumArray(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e52d797293cf3..d8f1961cb380a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -98,7 +98,8 @@ private[feature] object Bucketizer { false } else { var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { if (splits(i) >= splits(i + 1)) return false i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 2e6313ac14485..0f83a29c86bf6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -189,7 +189,8 @@ private object VectorIndexer { private def addDenseVector(dv: DenseVector): Unit = { var i = 0 - while (i < dv.size) { + val size = dv.size + while (i < size) { if (featureValueSets(i).size <= 
maxCategories) { featureValueSets(i).add(dv(i)) } @@ -201,7 +202,8 @@ private object VectorIndexer { // TODO: This might be able to handle 0's more efficiently. var vecIndex = 0 // index into vector var k = 0 // index into non-zero elements - while (vecIndex < sv.size) { + val size = sv.size + while (vecIndex < size) { val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { k += 1 sv.values(k - 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6377923afc0c4..36c242bb5f2a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -167,7 +167,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress val weights = { val rawWeights = state.x.toArray.clone() var i = 0 - while (i < rawWeights.length) { + val len = rawWeights.length + while (i < len) { rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } i += 1 } @@ -307,7 +308,8 @@ private class LeastSquaresAggregator( val weightsArray = weights.toArray.clone() var sum = 0.0 var i = 0 - while (i < weightsArray.length) { + val len = weightsArray.length + while (i < len) { if (featuresStd(i) != 0.0) { weightsArray(i) /= featuresStd(i) sum += weightsArray(i) * featuresMean(i) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c6057c7f837b1..9cc2d0ffcab7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -38,7 +38,8 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf protected def isSorted(array: Array[Int]): Boolean = { var i = 1 - while (i < array.length) { + val len = array.length + while (i < len) { if (array(i) < array(i-1)) return false i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 3ed3a5b9b3843..9f463e0cafb6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -116,7 +116,8 @@ class L1Updater extends Updater { // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize var i = 0 - while (i < brzWeights.length) { + val len = brzWeights.length + while (i < len) { val wi = brzWeights(i) brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index be2a00c2dfea4..4ce541ae5bed9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -69,7 +69,8 @@ class IsotonicRegressionModel ( /** Asserts the input array is monotone with the given ordering. 
*/ private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { var i = 1 - while (i < xs.length) { + val len = xs.length + while (i < len) { require(ord.compare(xs(i - 1), xs(i)) <= 0, s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") i += 1 @@ -329,11 +330,12 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali } var i = 0 - while (i < input.length) { + val len = input.length + while (i < len) { var j = i // Find monotonicity violating sequence, if any. - while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + while (j < len - 1 && input(j)._1 > input(j + 1)._1) { j = j + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fcc2a148791bd..0b1755613aac4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -70,23 +70,30 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") + val localCurrMean= currMean + val localCurrM2n = currM2n + val localCurrM2 = currM2 + val localCurrL1 = currL1 + val localNnz = nnz + val localCurrMax = currMax + val localCurrMin = currMin sample.foreachActive { (index, value) => if (value != 0.0) { - if (currMax(index) < value) { - currMax(index) = value + if (localCurrMax(index) < value) { + localCurrMax(index) = value } - if (currMin(index) > value) { - currMin(index) = value + if (localCurrMin(index) > value) { + localCurrMin(index) = value } - val prevMean = currMean(index) + val prevMean = localCurrMean(index) val diff = value - prevMean - currMean(index) = prevMean + diff / (nnz(index) + 1.0) - currM2n(index) += (value - currMean(index)) * diff - currM2(index) += value * value - currL1(index) += math.abs(value) + localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) + localCurrM2n(index) += (value - localCurrMean(index)) * diff + localCurrM2(index) += value * value + localCurrL1(index) += math.abs(value) - nnz(index) += 1.0 + localNnz(index) += 1.0 } } @@ -130,14 +137,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.clone - this.currM2n = other.currM2n.clone - this.currM2 = other.currM2.clone - this.currL1 = other.currL1.clone + this.currMean = other.currMean.clone() + this.currM2n = other.currM2n.clone() + this.currM2 = other.currM2.clone() + this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt - this.nnz = other.nnz.clone - this.currMax = other.currMax.clone - this.currMin = other.currMin.clone + this.nnz = other.nnz.clone() + this.currMax = other.currMax.clone() + this.currMin = other.currMin.clone() } this } @@ -165,7 +172,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if (denominator > 0.0) { val deltaMean = currMean var i = 0 - while (i < currM2n.size) { + val len = currM2n.length + while (i < len) { realVariance(i) = currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt realVariance(i) /= denominator @@ -211,7 +219,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMagnitude = 
Array.ofDim[Double](n) var i = 0 - while (i < currM2.size) { + val len = currM2.length + while (i < len) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index ea82d39b72c03..e597fce2babd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -205,8 +205,10 @@ private[stat] object ChiSqTest extends Logging { val colSums = new Array[Double](numCols) val rowSums = new Array[Double](numRows) val colMajorArr = counts.toArray + val colMajorArrLen = colMajorArr.length + var i = 0 - while (i < colMajorArr.size) { + while (i < colMajorArrLen) { val elem = colMajorArr(i) if (elem < 0.0) { throw new IllegalArgumentException("Contingency table cannot contain negative entries.") @@ -220,7 +222,7 @@ private[stat] object ChiSqTest extends Logging { // second pass to collect statistic var statistic = 0.0 var j = 0 - while (j < colMajorArr.size) { + while (j < colMajorArrLen) { val col = j / numRows val colSum = colSums(col) if (colSum == 0.0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 60e2ab2bb829e..72eb24c49264a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -111,11 +111,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * Add the stats from another calculator into this one, modifying and returning this calculator. */ def add(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be added with different counts sizes." + - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) += other.stats(i) i += 1 } @@ -127,11 +128,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * calculator. */ def subtract(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." 
+ - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) -= other.stats(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index b1a4517344970..b4e33c98ba7e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -107,7 +107,8 @@ object LinearDataGenerator { x.foreach { v => var i = 0 - while (i < v.length) { + val len = v.length + while (i < len) { v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 1900820400aee..20d2f3ac6696b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -122,7 +122,8 @@ private object BucketizerSuite extends FunSuite { def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { if (feature < splits(i + 1)) return i i += 1 } @@ -138,7 +139,8 @@ private object BucketizerSuite extends FunSuite { s" ${splits.mkString(", ")}") } var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { // Split i should fall in bucket i. testFeature(splits(i), i) // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index fb0a194718802..966811a5a3263 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -101,7 +101,8 @@ object LogisticRegressionSuite { // This doesn't work if `vector` is a sparse vector. 
val vectorArray = vector.toArray var i = 0 - while (i < vectorArray.length) { + val len = vectorArray.length + while (i < len) { vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) i += 1 } From 13e652b61a81b2d2e94088006fbd5fd4ed383e3d Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 14 May 2015 00:23:27 -0700 Subject: [PATCH 034/109] [SPARK-7595] [SQL] Window will cause resolve failed with self join for example: table: src(key string, value string) sql: with v1 as(select key, count(value) over (partition by key) cnt_val from src), v2 as(select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) select * from v2 limit 5; then will analyze fail when resolving conflicting references in Join: 'Limit 5 'Project [*] 'Subquery v2 'Project ['v1.key,'v1_lag.cnt_val] 'Filter ('v1.key = 'v1_lag.key) 'Join Inner, None Subquery v1 Project [key#95,cnt_val#94L] Window [key#95,value#96], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount(value#96) WindowSpecDefinition [key#95], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt_val#94L], WindowSpecDefinition [key#95], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING Project [key#95,value#96] MetastoreRelation default, src, None Subquery v1_lag Subquery v1 Project [key#97,cnt_val#94L] Window [key#97,value#98], [HiveWindowFunction#org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount(value#98) WindowSpecDefinition [key#97], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt_val#94L], WindowSpecDefinition [key#97], [], ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING Project [key#97,value#98] MetastoreRelation default, src, None Conflicting attributes: cnt_val#94L Author: linweizhong Closes #6114 from Sephiroth-Lin/spark-7595 and squashes the following commits: f8f2637 [linweizhong] Add unit test dfe9169 [linweizhong] Handle windowExpression with self join --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a4c61149dd975..4baeeb5b58c2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -322,6 +322,11 @@ class Analyzer( case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion @ Window(_, windowExpressions, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. 
sys.error( s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index eaa9d6aad1f31..5c7152e2140db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -763,4 +763,14 @@ class SQLQuerySuite extends QueryTest { sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) } + + test("SPARK-7595: Window will cause resolve failed with self join") { + checkAnswer(sql( + """ + |with + | v1 as (select key, count(value) over (partition by key) cnt_val from src), + | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) + | select * from v2 order by key limit 1 + """.stripMargin), Row(0, 3)) + } } From 1b8625f4258d6d1a049d0ba60e39e9757f5a568b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 14 May 2015 01:22:15 -0700 Subject: [PATCH 035/109] [SPARK-7407] [MLLIB] use uid + name to identify parameters A param instance is strongly attached to an parent in the current implementation. So if we make a copy of an estimator or a transformer in pipelines and other meta-algorithms, it becomes error-prone to copy the params to the copied instances. In this PR, a param is identified by its parent's UID and the param name. So it becomes loosely attached to its parent and all its derivatives. The UID is preserved during copying or fitting. All components now have a default constructor and a constructor that takes a UID as input. I keep the constructors for Param in this PR to reduce the amount of diff and moved `parent` as a mutable field. This PR still needs some clean-ups, and there are several spark.ml PRs pending. I'll try to get them merged first and then update this PR. 
jkbradley Author: Xiangrui Meng Closes #6019 from mengxr/SPARK-7407 and squashes the following commits: c4c8120 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 520f0a2 [Xiangrui Meng] address comments 2569168 [Xiangrui Meng] fix tests 873caca [Xiangrui Meng] fix tests in OneVsRest; fix a racing condition in shouldOwn 409ea08 [Xiangrui Meng] minor updates 83a163c [Xiangrui Meng] update JavaDeveloperApiExample 5db5325 [Xiangrui Meng] update OneVsRest 7bde7ae [Xiangrui Meng] merge master 697fdf9 [Xiangrui Meng] update Bucketizer 7b4f6c2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 629d402 [Xiangrui Meng] fix LRSuite 154516f [Xiangrui Meng] merge master aa4a611 [Xiangrui Meng] fix examples/compile a4794dd [Xiangrui Meng] change Param to use to reduce the size of diff fdbc415 [Xiangrui Meng] all tests passed c255f17 [Xiangrui Meng] fix tests in ParamsSuite 818e1db [Xiangrui Meng] merge master e1160cf [Xiangrui Meng] fix tests fbc39f0 [Xiangrui Meng] pass test:compile 108937e [Xiangrui Meng] pass compile 8726d39 [Xiangrui Meng] use parent uid in Param eaeed35 [Xiangrui Meng] update Identifiable --- .../examples/ml/JavaDeveloperApiExample.java | 43 +++++-- .../examples/ml/DeveloperApiExample.scala | 11 +- .../scala/org/apache/spark/ml/Model.scala | 10 +- .../scala/org/apache/spark/ml/Pipeline.scala | 11 +- .../scala/org/apache/spark/ml/Predictor.scala | 2 +- .../DecisionTreeClassifier.scala | 13 +- .../ml/classification/GBTClassifier.scala | 13 +- .../classification/LogisticRegression.scala | 11 +- .../spark/ml/classification/OneVsRest.scala | 27 ++-- .../RandomForestClassifier.scala | 13 +- .../BinaryClassificationEvaluator.scala | 7 +- .../apache/spark/ml/feature/Binarizer.scala | 7 +- .../apache/spark/ml/feature/Bucketizer.scala | 8 +- .../spark/ml/feature/ElementwiseProduct.scala | 6 +- .../apache/spark/ml/feature/HashingTF.scala | 5 +- .../org/apache/spark/ml/feature/IDF.scala | 10 +- .../apache/spark/ml/feature/Normalizer.scala | 5 +- .../spark/ml/feature/OneHotEncoder.scala | 8 +- .../ml/feature/PolynomialExpansion.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 10 +- .../spark/ml/feature/StringIndexer.scala | 10 +- .../apache/spark/ml/feature/Tokenizer.scala | 10 +- .../spark/ml/feature/VectorAssembler.scala | 6 +- .../spark/ml/feature/VectorIndexer.scala | 13 +- .../apache/spark/ml/feature/Word2Vec.scala | 10 +- .../org/apache/spark/ml/param/params.scala | 115 ++++++++++++------ .../apache/spark/ml/recommendation/ALS.scala | 10 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 13 +- .../ml/regression/LinearRegression.scala | 17 +-- .../ml/regression/RandomForestRegressor.scala | 12 +- .../spark/ml/tuning/CrossValidator.scala | 10 +- .../apache/spark/ml/util/Identifiable.scala | 21 +++- .../JavaLogisticRegressionSuite.java | 4 +- .../apache/spark/ml/param/JavaTestParams.java | 52 +++++--- .../regression/JavaLinearRegressionSuite.java | 4 +- .../spark/ml/util/IdentifiableSuite.scala | 40 ++++++ .../DecisionTreeClassifierSuite.scala | 2 +- .../classification/GBTClassifierSuite.scala | 2 +- .../LogisticRegressionSuite.scala | 39 +++--- .../ml/classification/OneVsRestSuite.scala | 6 +- .../RandomForestClassifierSuite.scala | 2 +- .../apache/spark/ml/param/ParamsSuite.scala | 16 ++- .../apache/spark/ml/param/TestParams.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 2 +- .../ml/regression/GBTRegressorSuite.scala | 3 +- .../RandomForestRegressorSuite.scala | 2 +- 47 
files changed, 452 insertions(+), 213 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index eac4f898a475d..ec533d174ebdc 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.ClassificationModel; import org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.util.Identifiable$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -103,7 +104,23 @@ public static void main(String[] args) throws Exception { * However, this should still compile and run successfully. */ class MyJavaLogisticRegression - extends Classifier { + extends Classifier { + + public MyJavaLogisticRegression() { + init(); + } + + public MyJavaLogisticRegression(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; + } /** * Param for max number of iterations @@ -117,7 +134,7 @@ class MyJavaLogisticRegression int getMaxIter() { return (Integer) getOrDefault(maxIter); } - public MyJavaLogisticRegression() { + private void init() { setMaxIter(100); } @@ -137,7 +154,7 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(this, weights); + return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } } @@ -149,17 +166,21 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { * However, this should still compile and run successfully. 
*/ class MyJavaLogisticRegressionModel - extends ClassificationModel { - - private MyJavaLogisticRegression parent_; - public MyJavaLogisticRegression parent() { return parent_; } + extends ClassificationModel { private Vector weights_; public Vector weights() { return weights_; } - public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) { - this.parent_ = parent_; - this.weights_ = weights_; + public MyJavaLogisticRegressionModel(String uid, Vector weights) { + this.uid_ = uid; + this.weights_ = weights; + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; } // This uses the default implementation of transform(), which reads column "features" and outputs @@ -204,6 +225,6 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 2a2d0677272a0..3ee456edbe01e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -20,6 +20,7 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -106,10 +107,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ -private class MyLogisticRegression +private class MyLogisticRegression(override val uid: String) extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] with MyLogisticRegressionParams { + def this() = this(Identifiable.randomUID("myLogReg")) + setMaxIter(100) // Initialize // The parameter setter is in this class since it should return type MyLogisticRegression. @@ -125,7 +128,7 @@ private class MyLogisticRegression val weights = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(this, weights) + new MyLogisticRegressionModel(uid, weights).setParent(this) } } @@ -135,7 +138,7 @@ private class MyLogisticRegression * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private class MyLogisticRegressionModel( - override val parent: MyLogisticRegression, + override val uid: String, val weights: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -173,6 +176,6 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. 
*/ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(parent, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 9974efe7b1d25..7fd515369b19b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -32,7 +32,15 @@ abstract class Model[M <: Model[M]] extends Transformer { * The parent estimator that produced this model. * Note: For ensembles' component Models, this value can be null. */ - val parent: Estimator[M] + var parent: Estimator[M] = _ + + /** + * Sets the parent of this model (Java API). + */ + def setParent(parent: Estimator[M]): M = { + this.parent = parent + this.asInstanceOf[M] + } override def copy(extra: ParamMap): M = { // The default implementation of Params.copy doesn't work for models. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 33d430f5671ee..fac54188f9f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -80,7 +81,9 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. */ @AlphaComponent -class Pipeline extends Estimator[PipelineModel] { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] { + + def this() = this(Identifiable.randomUID("pipeline")) /** * param for pipeline stages @@ -148,7 +151,7 @@ class Pipeline extends Estimator[PipelineModel] { } } - new PipelineModel(this, transformers.toArray) + new PipelineModel(uid, transformers.toArray).setParent(this) } override def copy(extra: ParamMap): Pipeline = { @@ -171,7 +174,7 @@ class Pipeline extends Estimator[PipelineModel] { */ @AlphaComponent class PipelineModel private[ml] ( - override val parent: Pipeline, + override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { @@ -190,6 +193,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(parent, stages) + new PipelineModel(uid, stages) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index f6a5f27425d1f..ec0f76aa668bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -88,7 +88,7 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). 
transformSchema(dataset.schema, logging = true) - copyValues(train(dataset)) + copyValues(train(dataset).setParent(this)) } override def copy(extra: ParamMap): Learner = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index dcebea1d4b015..7c961332bf5b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} @@ -39,10 +39,12 @@ import org.apache.spark.sql.DataFrame * features. */ @AlphaComponent -final class DecisionTreeClassifier +final class DecisionTreeClassifier(override val uid: String) extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + def this() = this(Identifiable.randomUID("dtc")) + // Override parameter setters from parent trait for Java API compatibility. override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) @@ -101,7 +103,7 @@ object DecisionTreeClassifier { */ @AlphaComponent final class DecisionTreeClassificationModel private[ml] ( - override val parent: DecisionTreeClassifier, + override val uid: String, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -114,7 +116,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) } override def toString: String = { @@ -138,6 +140,7 @@ private[ml] object DecisionTreeClassificationModel { s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). 
Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeClassificationModel(parent, rootNode) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") + new DecisionTreeClassificationModel(uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ae51b05a0c42d..d504d84beb91e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} @@ -44,10 +44,12 @@ import org.apache.spark.sql.DataFrame * Note: Multiclass labels are not currently supported. */ @AlphaComponent -final class GBTClassifier +final class GBTClassifier(override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + def this() = this(Identifiable.randomUID("gbtc")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: @@ -160,7 +162,7 @@ object GBTClassifier { */ @AlphaComponent final class GBTClassificationModel( - override val parent: GBTClassifier, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTClassificationModel] @@ -184,7 +186,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) } override def toString: String = { @@ -210,6 +212,7 @@ private[ml] object GBTClassificationModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. 
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTClassificationModel(parent, newTrees, oldModel.treeWeights) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") + new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 2b103626873a9..8694c96e4c5b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -26,6 +26,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction} import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint @@ -50,10 +51,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * Currently, this class only supports binary classification. */ @AlphaComponent -class LogisticRegression +class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with Logging { + def this() = this(Identifiable.randomUID("logreg")) + /** * Set the regularization parameter. * Default is 0.0. @@ -213,7 +216,7 @@ class LogisticRegression (weightsWithIntercept, 0.0) } - new LogisticRegressionModel(this, weights.compressed, intercept) + new LogisticRegressionModel(uid, weights.compressed, intercept) } } @@ -224,7 +227,7 @@ class LogisticRegression */ @AlphaComponent class LogisticRegressionModel private[ml] ( - override val parent: LogisticRegression, + override val uid: String, val weights: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] @@ -276,7 +279,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(parent, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index afb8d75d57384..1543f051ccd17 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.Param -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -40,19 +40,17 @@ private[ml] trait OneVsRestParams extends PredictorParams { type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] - type E <: Classifier[F, E, M] + type E <: Classifier[F, E, M] } /** * param for the base binary classifier that we reduce multiclass classification 
into. * @group param */ - val classifier: Param[ClassifierType] = - new Param(this, "classifier", "base binary classifier ") + val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") /** @group getParam */ def getClassifier: ClassifierType = $(classifier) - } /** @@ -70,10 +68,10 @@ private[ml] trait OneVsRestParams extends PredictorParams { * (taking label 0). */ @AlphaComponent -class OneVsRestModel private[ml] ( - override val parent: OneVsRest, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_,_]]) +final class OneVsRestModel private[ml] ( + override val uid: String, + labelMetadata: Metadata, + val models: Array[_ <: ClassificationModel[_,_]]) extends Model[OneVsRestModel] with OneVsRestParams { override def transformSchema(schema: StructType): StructType = { @@ -145,11 +143,13 @@ class OneVsRestModel private[ml] ( * is picked to label the example. */ @Experimental -final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { +final class OneVsRest(override val uid: String) + extends Estimator[OneVsRestModel] with OneVsRestParams { + + def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ - def setClassifier(value: Classifier[_,_,_]): this.type = { - // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed + def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } @@ -204,6 +204,7 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) case attr: Attribute => attr } - copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models)) + val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) + copyValues(model) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 9954893f14359..a1de7919859eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} @@ -41,10 +41,12 @@ import org.apache.spark.sql.DataFrame * features. */ @AlphaComponent -final class RandomForestClassifier +final class RandomForestClassifier(override val uid: String) extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + def this() = this(Identifiable.randomUID("rfc")) + // Override parameter setters from parent trait for Java API compatibility. 
// Parameters from TreeClassifierParams: @@ -118,7 +120,7 @@ object RandomForestClassifier { */ @AlphaComponent final class RandomForestClassificationModel private[ml] ( - override val parent: RandomForestClassifier, + override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -146,7 +148,7 @@ final class RandomForestClassificationModel private[ml] ( } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(parent, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees), extra) } override def toString: String = { @@ -172,6 +174,7 @@ private[ml] object RandomForestClassificationModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestClassificationModel(parent, newTrees) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") + new RandomForestClassificationModel(uid, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index e5a73c6087a11..c1af09c9694ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,10 @@ import org.apache.spark.sql.types.DoubleType * Evaluator for binary classification, which expects two input columns: score and label. */ @AlphaComponent -class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol { +class BinaryClassificationEvaluator(override val uid: String) + extends Evaluator with HasRawPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("binEval")) /** * param for metric name in evaluation diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 6eb1db6971111..62f4a6343423e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -32,7 +32,10 @@ import org.apache.spark.sql.types.{DoubleType, StructType} * Binarize a column of continuous features given a threshold. 
*/ @AlphaComponent -final class Binarizer extends Transformer with HasInputCol with HasOutputCol { +final class Binarizer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("binarizer")) /** * Param for threshold used to binarize continuous features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d8f1961cb380a..ac8dfb5632a7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -21,11 +21,11 @@ import java.{util => ju} import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -35,10 +35,10 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ @AlphaComponent -final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer]) +final class Bucketizer(override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol { - def this() = this(null) + def this() = this(Identifiable.randomUID("bucketizer")) /** * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index f8b56293e3ccc..8b32eee0e490a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.Param +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -31,7 +32,10 @@ import org.apache.spark.sql.types.DataType * multiplier. 
*/ @AlphaComponent -class ElementwiseProduct extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { +class ElementwiseProduct(override val uid: String) + extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { + + def this() = this(Identifiable.randomUID("elemProd")) /** * the vector to multiply with input vectors diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index c305a819a8966..30033ced68a04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -29,7 +30,9 @@ import org.apache.spark.sql.types.DataType * Maps a sequence of terms to their term frequencies using the hashing trick. */ @AlphaComponent -class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { +class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + + def this() = this(Identifiable.randomUID("hashingTF")) /** * Number of features. Should be > 0. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index d901a20aed002..788c392050c2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -62,7 +62,9 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. 
*/ @AlphaComponent -final class IDF extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { + + def this() = this(Identifiable.randomUID("idf")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -74,7 +76,7 @@ final class IDF extends Estimator[IDFModel] with IDFBase { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) - copyValues(new IDFModel(this, idf)) + copyValues(new IDFModel(uid, idf).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -88,7 +90,7 @@ final class IDF extends Estimator[IDFModel] with IDFBase { */ @AlphaComponent class IDFModel private[ml] ( - override val parent: IDF, + override val uid: String, idfModel: feature.IDFModel) extends Model[IDFModel] with IDFBase { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 755b46a64c7f1..3f689d1585cd6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -29,7 +30,9 @@ import org.apache.spark.sql.types.DataType * Normalize a vector to have unit norm using the given p-norm. */ @AlphaComponent -class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { +class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { + + def this() = this(Identifiable.randomUID("normalizer")) /** * Normalization in L^p^ space. Must be >= 1. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 46514ae5f0e84..1fb9b9ae75091 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribu import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -37,8 +37,10 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * linearly dependent because they sum up to one. */ @AlphaComponent -class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] - with HasInputCol with HasOutputCol { +class OneHotEncoder(override val uid: String) + extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("oneHot")) /** * Whether to include a component in the encoded vectors for the first category, defaults to true. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 9e6177ca27e4a..41564410e4965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -34,7 +35,10 @@ import org.apache.spark.sql.types.DataType * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ @AlphaComponent -class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { +class PolynomialExpansion(override val uid: String) + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + def this() = this(Identifiable.randomUID("poly")) /** * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 7cad59ff3fa37..5ccda15d872ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -55,7 +56,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with * statistics on the samples in the training set. 
*/ @AlphaComponent -class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { +class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] + with StandardScalerParams { + + def this() = this(Identifiable.randomUID("stdScal")) setDefault(withMean -> false, withStd -> true) @@ -76,7 +80,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(this, scalerModel)) + copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -96,7 +100,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP */ @AlphaComponent class StandardScalerModel private[ml] ( - override val parent: StandardScaler, + override val uid: String, scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3d78537ad84cb..3f79b67309f07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,6 +23,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{NumericType, StringType, StructType} @@ -58,7 +59,10 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * So the most frequent label gets index 0. 
*/ @AlphaComponent -class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { +class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] + with StringIndexerBase { + + def this() = this(Identifiable.randomUID("strIdx")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -73,7 +77,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray - copyValues(new StringIndexerModel(this, labels)) + copyValues(new StringIndexerModel(uid, labels).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -87,7 +91,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase */ @AlphaComponent class StringIndexerModel private[ml] ( - override val parent: StringIndexer, + override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { private val labelToIndex: OpenHashMap[String, Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 649c217b16590..36d9e17eca41b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -27,7 +28,9 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * A tokenizer that converts the input string to lowercase and then splits it by white spaces. */ @AlphaComponent -class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { +class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { + + def this() = this(Identifiable.randomUID("tok")) override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") @@ -48,7 +51,10 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { * It returns an array of strings that can be empty. */ @AlphaComponent -class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { +class RegexTokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + def this() = this(Identifiable.randomUID("regexTok")) /** * Minimum token length, >= 0. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 796758a70ef18..1c0009476908c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +34,10 @@ import org.apache.spark.sql.types._ * A feature transformer that merges multiple columns into a vector column. */ @AlphaComponent -class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { +class VectorAssembler(override val uid: String) + extends Transformer with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("va")) /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 0f83a29c86bf6..6d1d0524e59ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.callUDF @@ -87,7 +87,10 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add option for allowing unknown categories. 
*/ @AlphaComponent -class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { +class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] + with VectorIndexerParams { + + def this() = this(Identifiable.randomUID("vecIdx")) /** @group setParam */ def setMaxCategories(value: Int): this.type = set(maxCategories, value) @@ -110,7 +113,9 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara iter.foreach(localCatStats.addVector) Iterator(localCatStats) }.reduce((stats1, stats2) => stats1.merge(stats2)) - copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps)) + val model = new VectorIndexerModel(uid, numFeatures, categoryStats.getCategoryMaps) + .setParent(this) + copyValues(model) } override def transformSchema(schema: StructType): StructType = { @@ -238,7 +243,7 @@ private object VectorIndexer { */ @AlphaComponent class VectorIndexerModel private[ml] ( - override val parent: VectorIndexer, + override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) extends Model[VectorIndexerModel] with VectorIndexerParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 34ff92970129f..8ace8c53bb663 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -85,7 +85,9 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. 
*/ @AlphaComponent -final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { + + def this() = this(Identifiable.randomUID("w2v")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -122,7 +124,7 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { .setSeed($(seed)) .setVectorSize($(vectorSize)) .fit(input) - copyValues(new Word2VecModel(this, wordVectors)) + copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -136,7 +138,7 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { */ @AlphaComponent class Word2VecModel private[ml] ( - override val parent: Word2Vec, + override val uid: String, wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5a7ec29aac6cc..247e08be1bb15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -40,12 +40,17 @@ import org.apache.spark.ml.util.Identifiable * @tparam T param value type */ @AlphaComponent -class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean) +class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { - def this(parent: Params, name: String, doc: String) = + def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue[T]) + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** * Assert that the given value is valid for this parameter. * @@ -60,8 +65,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." + - s" Parameter description: $toString") + throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") } } @@ -75,19 +79,15 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal */ def ->(value: T): ParamPair[T] = ParamPair(this, value) - /** - * Converts this param's name, doc, and optionally its default value and the user-supplied - * value in its parent to string. - */ - override def toString: String = { - val valueStr = if (parent.isDefined(this)) { - val defaultValueStr = parent.getDefault(this).map("default: " + _) - val currentValueStr = parent.get(this).map("current: " + _) - (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") - } else { - "(undefined)" + override final def toString: String = s"${parent}__$name" + + override final def hashCode: Int = toString.## + + override final def equals(obj: Any): Boolean = { + obj match { + case p: Param[_] => (p.parent == parent) && (p.name == name) + case _ => false } - s"$name: $doc $valueStr" } } @@ -173,49 +173,71 @@ object ParamValidators { // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... 
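The params.scala hunk above changes what identifies a `Param`: instead of a reference to its parent `Params` object, a param now records only the parent's uid string, and `toString`, `hashCode`, and `equals` are all derived from the `"<uid>__<name>"` pair. A short sketch of the behaviour implied by the new definitions (the uid literal is invented for illustration):

```scala
import org.apache.spark.ml.param.{IntParam, ParamValidators}

// Two params built against the same parent uid and name compare equal,
// even though they are distinct objects with different validators.
val a = new IntParam("stage_0123456789ab", "maxIter", "max number of iterations", ParamValidators.gt(0))
val b = new IntParam("stage_0123456789ab", "maxIter", "max number of iterations")

assert(a.toString == "stage_0123456789ab__maxIter") // s"${parent}__$name"
assert(a == b && a.hashCode == b.hashCode)
assert(!a.isValid(-1) && a.isValid(5))
```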
/** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean) +class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) extends Param[Double](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean) +class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) extends Param[Int](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean) +class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) extends Param[Float](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean) +class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) extends Param[Long](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid +class BooleanParam(parent: String, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -265,6 +287,9 @@ trait Params extends Identifiable with Serializable { /** * Returns all params sorted by their names. 
The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. + * + * Note: Developer should not use this method in constructor because we cannot guarantee that + * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods @@ -299,15 +324,36 @@ trait Params extends Identifiable with Serializable { * those are checked during schema validation. */ def validateParams(): Unit = { - params.filter(isDefined _).foreach { param => + params.filter(isDefined).foreach { param => param.asInstanceOf[Param[Any]].validate($(param)) } } /** - * Returns the documentation of all params. + * Explains a param. + * @param param input param, must belong to this instance. + * @return a string that contains the input param name, doc, and optionally its default value and + * the user-supplied value + */ + def explainParam(param: Param[_]): String = { + shouldOwn(param) + val valueStr = if (isDefined(param)) { + val defaultValueStr = getDefault(param).map("default: " + _) + val currentValueStr = get(param).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") + } else { + "(undefined)" + } + s"${param.name}: ${param.doc} $valueStr" + } + + /** + * Explains all params of this instance. + * @see [[explainParam()]] */ - def explainParams(): String = params.mkString("\n") + def explainParams(): String = { + params.map(explainParam).mkString("\n") + } /** Checks whether a param is explicitly set. */ final def isSet(param: Param[_]): Boolean = { @@ -392,7 +438,6 @@ trait Params extends Identifiable with Serializable { * @param value the default value */ protected final def setDefault[T](param: Param[T], value: T): this.type = { - shouldOwn(param) defaultParamMap.put(param, value) this } @@ -430,13 +475,13 @@ trait Params extends Identifiable with Serializable { } /** - * Creates a copy of this instance with a randomly generated uid and some extra params. - * The default implementation calls the default constructor to create a new instance, then - * copies the embedded and extra parameters over and returns the new instance. + * Creates a copy of this instance with the same UID and some extra params. + * The default implementation tries to create a new instance with the same UID. + * Then it copies the embedded and extra parameters over and returns the new instance. * Subclasses should override this method if the default approach is not sufficient. */ def copy(extra: ParamMap): Params = { - val that = this.getClass.newInstance() + val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) copyValues(that, extra) that } @@ -465,7 +510,7 @@ trait Params extends Identifiable with Serializable { /** Validates that the input param belongs to this instance. 
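Two consequences of the `Params` changes above are easy to miss in the diff: the human-readable description that used to live in `Param.toString` moved into `explainParam`/`explainParams`, and the default `copy` now goes through the single-`String` constructor, so the copy keeps the original uid. A quick illustration against `LogisticRegression` (the printed uid and default value are indicative only; they come from the running instance):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression().setMaxIter(5)

// Param.toString is now just "<uid>__<name>"; the description lives in explainParam.
println(lr.maxIter)                   // e.g. logreg_ab12cd34ef56__maxIter
println(lr.explainParam(lr.maxIter))  // e.g. maxIter: max number of iterations (>= 0) (default: 100, current: 5)
println(lr.explainParams())           // one such line per param

// copy() preserves the uid instead of generating a new one.
val copied = lr.copy(ParamMap(lr.regParam -> 0.1)).asInstanceOf[LogisticRegression]
assert(copied.uid == lr.uid && copied.getRegParam == 0.1)
```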
*/ private def shouldOwn(param: Param[_]): Unit = { - require(param.parent.eq(this), s"Param $param does not belong to $this.") + require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.") } /** @@ -581,7 +626,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) override def toString: String = { map.toSeq.sortBy(_._1.name).map { case (param, value) => - s"\t${param.parent.uid}-${param.name}: $value" + s"\t${param.parent}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d7cbffc3be26f..45c57b50da70f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -171,7 +172,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Model fitted by ALS. */ class ALSModel private[ml] ( - override val parent: ALS, + override val uid: String, k: Int, userFactors: RDD[(Int, Array[Float])], itemFactors: RDD[(Int, Array[Float])]) @@ -235,10 +236,12 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. */ -class ALS extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + def this() = this(Identifiable.randomUID("als")) + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) @@ -303,7 +306,8 @@ class ALS extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - copyValues(new ALSModel(this, $(rank), userFactors, itemFactors)) + val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this) + copyValues(model) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index f8f0b161a4812..e67df21b2e4ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} @@ -38,10 +38,12 @@ import org.apache.spark.sql.DataFrame * It supports both continuous and categorical features. 
*/ @AlphaComponent -final class DecisionTreeRegressor +final class DecisionTreeRegressor(override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { + def this() = this(Identifiable.randomUID("dtr")) + // Override parameter setters from parent trait for Java API compatibility. override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) @@ -91,7 +93,7 @@ object DecisionTreeRegressor { */ @AlphaComponent final class DecisionTreeRegressionModel private[ml] ( - override val parent: DecisionTreeRegressor, + override val uid: String, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with Serializable { @@ -104,7 +106,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) } override def toString: String = { @@ -128,6 +130,7 @@ private[ml] object DecisionTreeRegressionModel { s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeRegressionModel(parent, rootNode) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") + new DecisionTreeRegressionModel(uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 461905c12701a..4249ff5c1ebc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} @@ -42,10 +42,12 @@ import org.apache.spark.sql.DataFrame * It supports both continuous and categorical features. */ @AlphaComponent -final class GBTRegressor +final class GBTRegressor(override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { + def this() = this(Identifiable.randomUID("gbtr")) + // Override parameter setters from parent trait for Java API compatibility. 
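One detail in the `fromOld` converters above: the old-API model may carry a `null` new-API parent, so the uid is taken from the parent only when one exists and is freshly minted otherwise. Reduced to its essence (a sketch; the helper name is invented):

```scala
import org.apache.spark.ml.util.Identifiable

// Mirrors the guard added in DecisionTreeRegressionModel.fromOld above.
def uidFrom(parent: Identifiable, fallbackPrefix: String): String =
  if (parent != null) parent.uid else Identifiable.randomUID(fallbackPrefix)

assert(uidFrom(null, "dtr").startsWith("dtr_"))
```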
// Parameters from TreeRegressorParams: @@ -149,7 +151,7 @@ object GBTRegressor { */ @AlphaComponent final class GBTRegressionModel( - override val parent: GBTRegressor, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTRegressionModel] @@ -173,7 +175,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) } override def toString: String = { @@ -199,6 +201,7 @@ private[ml] object GBTRegressionModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTRegressionModel(parent, newTrees, oldModel.treeWeights) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") + new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 36c242bb5f2a7..3ebb78f79201a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV, norm => brzNorm} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, - OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint @@ -59,9 +59,12 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L2 + L1 (elastic net) */ @AlphaComponent -class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] +class LinearRegression(override val uid: String) + extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { + def this() = this(Identifiable.randomUID("linReg")) + /** * Set the regularization parameter. * Default is 0.0. @@ -128,7 +131,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean) + return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean) } val featuresMean = summarizer.mean.toArray @@ -182,7 +185,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. 
- new LinearRegressionModel(this, weights.compressed, intercept) + copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } } @@ -193,7 +196,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress */ @AlphaComponent class LinearRegressionModel private[ml] ( - override val parent: LinearRegression, + override val uid: String, val weights: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] @@ -204,7 +207,7 @@ class LinearRegressionModel private[ml] ( } override def copy(extra: ParamMap): LinearRegressionModel = { - copyValues(new LinearRegressionModel(parent, weights, intercept), extra) + copyValues(new LinearRegressionModel(uid, weights, intercept), extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index dbc628927433d..82437aa8de294 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} @@ -37,10 +37,12 @@ import org.apache.spark.sql.DataFrame * It supports both continuous and categorical features. */ @AlphaComponent -final class RandomForestRegressor +final class RandomForestRegressor(override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { + def this() = this(Identifiable.randomUID("rfr")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeRegressorParams: @@ -105,7 +107,7 @@ object RandomForestRegressor { */ @AlphaComponent final class RandomForestRegressionModel private[ml] ( - override val parent: RandomForestRegressor, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel]) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -128,7 +130,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(parent, _trees), extra) + copyValues(new RandomForestRegressionModel(uid, _trees), extra) } override def toString: String = { @@ -154,6 +156,6 @@ private[ml] object RandomForestRegressionModel { // parent, fittingParamMap for each tree is null since there are no good ways to set these. 
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent, newTrees) + new RandomForestRegressionModel(parent.uid, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ac0d1fed84b2e..5c6ff2dda3604 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -23,6 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -81,7 +82,10 @@ private[ml] trait CrossValidatorParams extends Params { * K-fold cross validation. */ @AlphaComponent -class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { +class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] + with CrossValidatorParams with Logging { + + def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS @@ -136,7 +140,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(this, bestModel)) + copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -150,7 +154,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP */ @AlphaComponent class CrossValidatorModel private[ml] ( - override val parent: CrossValidator, + override val uid: String, val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index 8a56748ab0a02..146697680092c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,15 +19,24 @@ package org.apache.spark.ml.util import java.util.UUID + /** - * Object with a unique id. + * Trait for an object with an immutable unique ID that identifies itself and its derivatives. */ -private[ml] trait Identifiable extends Serializable { +trait Identifiable { + + /** + * An immutable unique ID for the object and its derivatives. + */ + val uid: String +} + +object Identifiable { /** - * A unique id for the object. The default implementation concatenates the class name, "_", and 8 - * random hex chars. + * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. 
*/ - private[ml] val uid: String = - this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) + def randomUID(prefix: String): String = { + prefix + "_" + UUID.randomUUID().toString.takeRight(12) + } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e7189a2b1d53..f75e024a713ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -84,7 +84,7 @@ public void logisticRegressionWithSetters() { .setThreshold(0.6) .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - LogisticRegression parent = model.parent(); + LogisticRegression parent = (LogisticRegression) model.parent(); assert(parent.getMaxIter() == 10); assert(parent.getRegParam() == 1.0); assert(parent.getThreshold() == 0.6); @@ -110,7 +110,7 @@ public void logisticRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); - LogisticRegression parent2 = model2.parent(); + LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); assert(parent2.getThreshold() == 0.4); diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 8abe575610d19..3a41890b92d63 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -21,43 +21,65 @@ import com.google.common.collect.Lists; +import org.apache.spark.ml.util.Identifiable$; + /** * A subclass of Params for testing. 
*/ public class JavaTestParams extends JavaParams { - public IntParam myIntParam; + public JavaTestParams() { + this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams"); + init(); + } + + public JavaTestParams(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_; + + @Override + public String uid() { + return uid_; + } - public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); } + private IntParam myIntParam_; + public IntParam myIntParam() { return myIntParam_; } + + public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam, value); return this; + set(myIntParam_, value); return this; } - public DoubleParam myDoubleParam; + private DoubleParam myDoubleParam_; + public DoubleParam myDoubleParam() { return myDoubleParam_; } - public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); } + public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam, value); return this; + set(myDoubleParam_, value); return this; } - public Param myStringParam; + private Param myStringParam_; + public Param myStringParam() { return myStringParam_; } - public String getMyStringParam() { return (String)getOrDefault(myStringParam); } + public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam, value); return this; + set(myStringParam_, value); return this; } - public JavaTestParams() { - myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); - myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param", + private void init() { + myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); + myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); List validStrings = Lists.newArrayList("a", "b"); - myStringParam = new Param(this, "myStringParam", "this is a string param", + myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - setDefault(myIntParam, 1); - setDefault(myDoubleParam, 0.5); + setDefault(myIntParam_, 1); + setDefault(myDoubleParam_, 0.5); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index a82b86d560b6e..d591a456864e4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -77,14 +77,14 @@ public void linearRegressionWithSetters() { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - LinearRegression parent = model.parent(); + LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); assertEquals(1.0, parent.getRegParam(), 0.0); // Call fit() with new params, and check as many params as we can. 
LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - LinearRegression parent2 = model2.parent(); + LinearRegression parent2 = (LinearRegression) model2.parent(); assertEquals(5, parent2.getMaxIter()); assertEquals(0.1, parent2.getRegParam(), 0.0); assertEquals("thePred", model2.getPredictionCol()); diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala new file mode 100644 index 0000000000000..67c262d0f9d8d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.scalatest.FunSuite + +class IdentifiableSuite extends FunSuite { + + import IdentifiableSuite.Test + + test("Identifiable") { + val test0 = new Test("test_0") + assert(test0.uid === "test_0") + + val test1 = new Test + assert(test1.uid.startsWith("test_")) + } +} + +object IdentifiableSuite { + + class Test(override val uid: String) extends Identifiable { + def this() = this(Identifiable.randomUID("test")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 03af4ecd7a7e0..3fdc66be8a314 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -268,7 +268,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( - oldTree, newTree.parent, categoricalFeatures) + oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 16c758b82c7cd..ea86867f1161a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -130,7 +130,7 @@ private object GBTClassifierSuite { val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. 
val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 4df8016009171..43765241a20b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.ml.classification import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} - class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @@ -37,8 +36,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 4)) + dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) /** * Here is the instruction describing how to export the test data into CSV format @@ -60,31 +58,30 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) - sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 4)) + sqlContext.createDataFrame( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) } } test("logistic regression: default params") { val lr = new LogisticRegression - assert(lr.getLabelCol == "label") - assert(lr.getFeaturesCol == "features") - assert(lr.getPredictionCol == "prediction") - assert(lr.getRawPredictionCol == "rawPrediction") - assert(lr.getProbabilityCol == "probability") - assert(lr.getFitIntercept == true) + assert(lr.getLabelCol === "label") + assert(lr.getFeaturesCol === "features") + assert(lr.getPredictionCol === "prediction") + assert(lr.getRawPredictionCol === "rawPrediction") + assert(lr.getProbabilityCol === "probability") + assert(lr.getFitIntercept) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") .collect() assert(model.getThreshold === 0.5) - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") - assert(model.getRawPredictionCol == "rawPrediction") - assert(model.getProbabilityCol == "probability") + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") assert(model.intercept !== 
0.0) } @@ -103,7 +100,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setThreshold(0.6) .setProbabilityCol("myProbability") val model = lr.fit(dataset) - val parent = model.parent + val parent = model.parent.asInstanceOf[LogisticRegression] assert(parent.getMaxIter === 10) assert(parent.getRegParam === 1.0) assert(parent.getThreshold === 0.6) @@ -129,12 +126,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { // Call fit() with new params, and check as many params as we can. val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, lr.probabilityCol -> "theProb") - val parent2 = model2.parent + val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) assert(parent2.getRegParam === 0.1) assert(parent2.getThreshold === 0.4) assert(model2.getThreshold === 0.4) - assert(model2.getProbabilityCol == "theProb") + assert(model2.getProbabilityCol === "theProb") } test("logistic regression: Predictor, Classifier methods") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index e65ffae918ca9..990cfb08af83b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -57,7 +57,7 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() - ova.setClassifier(new LogisticRegression) + .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) @@ -97,7 +97,9 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { } } -private class MockLogisticRegression extends LogisticRegression { +private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { + + def this() = this("mockLogReg") setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index c41def9330504..08f86fa45bc1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -160,7 +160,7 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. 
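The test updates above share one mechanical change: `model.parent` is no longer typed to the concrete estimator, so callers that need estimator-specific getters cast it back. A self-contained sketch of that usage (the local-mode setup and two-row dataset are only for illustration):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("parent-cast"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val dataset = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)))).toDF()

val model = new LogisticRegression().setMaxIter(10).fit(dataset)
// model.parent is typed only as an Estimator; cast back for estimator-specific getters.
val parent = model.parent.asInstanceOf[LogisticRegression]
assert(parent.getMaxIter == 10)

sc.stop()
```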
val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 6056e7d3f6ff8..b96874f3a8821 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -23,21 +23,22 @@ class ParamsSuite extends FunSuite { test("param") { val solver = new TestParams() + val uid = solver.uid import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations (>= 0)") - assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)") + assert(maxIter.parent === uid) + assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) assert(maxIter.isValid(0)) assert(maxIter.isValid(1)) solver.setMaxIter(5) - assert(maxIter.toString === + assert(solver.explainParam(maxIter) === "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") - assert(inputCol.toString === "inputCol: input column name (undefined)") + assert(inputCol.toString === s"${uid}__inputCol") intercept[IllegalArgumentException] { solver.setMaxIter(-1) @@ -118,7 +119,10 @@ class ParamsSuite extends FunSuite { assert(!solver.isDefined(inputCol)) intercept[NoSuchElementException](solver.getInputCol) - assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.explainParam(maxIter) === + "maxIter: max number of iterations (>= 0) (default: 10, current: 100)") + assert(solver.explainParams() === + Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) @@ -148,7 +152,7 @@ class ParamsSuite extends FunSuite { assert(!solver.isSet(maxIter)) val copied = solver.copy(ParamMap(solver.maxIter -> 50)) - assert(copied.uid !== solver.uid) + assert(copied.uid === solver.uid) assert(copied.getInputCol === solver.getInputCol) assert(copied.getMaxIter === 50) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index dc16073640407..a9e78366ad98f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -18,9 +18,12 @@ package org.apache.spark.ml.param import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. 
*/ -class TestParams extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { + + def this() = this(Identifiable.randomUID("testParams")) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 5aa81b44ddaf9..1196a772dfdd4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -85,7 +85,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val newTree = dt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( - oldTree, newTree.parent, categoricalFeatures) + oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 25b36ab08b67c..40e7e3273e965 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -130,7 +130,8 @@ private object GBTRegressorSuite { val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures) + val oldModelAsNew = GBTRegressionModel.fromOld( + oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 45f09f4fdab81..3efffbb763b78 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -116,7 +116,7 @@ private object RandomForestRegressorSuite extends FunSuite { val newModel = rf.fit(newData) // Use parent, fittingParamMap from newTree since these are not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } From c1080b6fddb22d84694da2453e46a03fbc041576 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 14 May 2015 01:26:08 -0700 Subject: [PATCH 036/109] [SPARK-7568] [ML] ml.LogisticRegression doesn't output the right prediction The difference is because we previously don't fit the intercept in Spark 1.3. Here, we change the input `String` so that the probability of instance 6 can be classified as `1.0` without any ambiguity. 
with lambda = 0.001 in current LOR implementation, the prediction is ``` (4, spark i j k) --> prob=[0.1596407738787411,0.8403592261212589], prediction=1.0 (5, l m n) --> prob=[0.8378325685476612,0.16216743145233883], prediction=0.0 (6, spark hadoop spark) --> prob=[0.0692663313297627,0.9307336686702373], prediction=1.0 (7, apache hadoop) --> prob=[0.9821575333444208,0.01784246665557917], prediction=0.0 ``` and the training accuracy is ``` (0, a b c d e spark) --> prob=[0.0021342419881406746,0.9978657580118594], prediction=1.0 (1, b d) --> prob=[0.9959176174854043,0.004082382514595685], prediction=0.0 (2, spark f g h) --> prob=[0.0014541569986711233,0.9985458430013289], prediction=1.0 (3, hadoop mapreduce) --> prob=[0.9982978367343561,0.0017021632656438518], prediction=0.0 ``` Author: DB Tsai Closes #6109 from dbtsai/lor-example and squashes the following commits: ac63ce4 [DB Tsai] first commit --- .../examples/ml/JavaSimpleTextClassificationPipeline.java | 4 ++-- .../src/main/python/ml/simple_text_classification_pipeline.py | 4 ++-- .../spark/examples/ml/SimpleTextClassificationPipeline.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index ef1ec103a879f..54738813d0016 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setOutputCol("features"); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01); + .setRegParam(0.001); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); @@ -77,7 +77,7 @@ public static void main(String[] args) { List localTest = Lists.newArrayList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), + new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index fab21f003b233..b4f06bf888746 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -48,7 +48,7 @@ # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") - lr = LogisticRegression(maxIter=10, regParam=0.01) + lr = LogisticRegression(maxIter=10, regParam=0.001) pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. 
@@ -58,7 +58,7 @@ Document = Row("id", "text") test = sc.parallelize([(4, "spark i j k"), (5, "l m n"), - (6, "mapreduce spark"), + (6, "spark hadoop spark"), (7, "apache hadoop")]) \ .map(lambda x: Document(*x)).toDF() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 6772efd2c581c..1324b066c30c3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -64,7 +64,7 @@ object SimpleTextClassificationPipeline { .setOutputCol("features") val lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01) + .setRegParam(0.001) val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) @@ -75,7 +75,7 @@ object SimpleTextClassificationPipeline { val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), - Document(6L, "mapreduce spark"), + Document(6L, "spark hadoop spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. From 7fb715de6d90c3eb756935440f75b1de674f8ece Mon Sep 17 00:00:00 2001 From: FavioVazquez Date: Thu, 14 May 2015 15:22:58 +0100 Subject: [PATCH 037/109] [SPARK-7249] Updated Hadoop dependencies due to inconsistency in the versions Updated Hadoop dependencies due to inconsistency in the versions. Now the global properties are the ones used by the hadoop-2.2 profile, and the profile was set to empty but kept for backwards compatibility reasons. Changes proposed by vanzin resulting from previous pull-request https://github.com/apache/spark/pull/5783 that did not fixed the problem correctly. Please let me know if this is the correct way of doing this, the comments of vanzin are in the pull-request mentioned. 
Author: FavioVazquez Closes #5786 from FavioVazquez/update-hadoop-dependencies and squashes the following commits: 11670e5 [FavioVazquez] - Added missing instance of -Phadoop-2.2 in create-release.sh 379f50d [FavioVazquez] - Added instances of -Phadoop-2.2 in create-release.sh, run-tests, scalastyle and building-spark.md - Reconstructed docs to not ask users to rely on default behavior 3f9249d [FavioVazquez] Merge branch 'master' of https://github.com/apache/spark into update-hadoop-dependencies 31bdafa [FavioVazquez] - Added missing instances in -Phadoop-1 in create-release.sh, run-tests and in the building-spark documentation cbb93e8 [FavioVazquez] - Added comment related to SPARK-3710 about hadoop-yarn-server-tests in Hadoop 2.2 that fails to pull some needed dependencies 83dc332 [FavioVazquez] - Cleaned up the main POM concerning the yarn profile - Erased hadoop-2.2 profile from yarn/pom.xml and its content was integrated into yarn/pom.xml 93f7624 [FavioVazquez] - Deleted unnecessary comments and tag on the YARN profile in the main POM 668d126 [FavioVazquez] - Moved and sections of the hadoop-2.2 profile in the YARN POM to the YARN profile in the root POM - Erased unnecessary hadoop-2.2 profile from the YARN POM fda6a51 [FavioVazquez] - Updated hadoop1 releases in create-release.sh due to changes in the default hadoop version set - Erased unnecessary instance of -Dyarn.version=2.2.0 in create-release.sh - Prettify comment in yarn/pom.xml 0470587 [FavioVazquez] - Erased unnecessary instance of -Phadoop-2.2 -Dhadoop.version=2.2.0 in create-release.sh - Updated how the releases are made in the create-release.sh no that the default hadoop version is the 2.2.0 - Erased unnecessary instance of -Phadoop-2.2 -Dhadoop.version=2.2.0 in scalastyle - Erased unnecessary instance of -Phadoop-2.2 -Dhadoop.version=2.2.0 in run-tests - Better example given in the hadoop-third-party-distributions.md now that the default hadoop version is 2.2.0 a650779 [FavioVazquez] - Default value of avro.mapred.classifier has been set to hadoop2 in pom.xml - Cleaned up hadoop-2.3 and 2.4 profiles due to change in the default set in avro.mapred.classifier in pom.xml 199f40b [FavioVazquez] - Erased unnecessary CDH5-specific note in docs/building-spark.md - Remove example of instance -Phadoop-2.2 -Dhadoop.version=2.2.0 in docs/building-spark.md - Enabled hadoop-2.2 profile when the Hadoop version is 2.2.0, which is now the default .Added comment in the yarn/pom.xml to specify that. 88a8b88 [FavioVazquez] - Simplified Hadoop profiles due to new setting of global properties in the pom.xml file - Added comment to specify that the hadoop-2.2 profile is now the default hadoop profile in the pom.xml file - Erased hadoop-2.2 from related hadoop profiles now that is a no-op in the make-distribution.sh file 70b8344 [FavioVazquez] - Fixed typo in the make-distribution.sh file and added hadoop-1 in the Related profiles 287fa2f [FavioVazquez] - Updated documentation about specifying the hadoop version in building-spark. Now is clear that Spark will build against Hadoop 2.2.0 by default. - Added Cloudera CDH 5.3.3 without MapReduce example in the building-spark doc. 1354292 [FavioVazquez] - Fixed hadoop-1 version to match jenkins build profile in hadoop1.0 tests and documentation 6b4bfaf [FavioVazquez] - Cleanup in hadoop-2.x profiles since they contained mostly redundant stuff. 7e9955d [FavioVazquez] - Updated Hadoop dependencies due to inconsistency in the versions. 
Now the global properties are the ones used by the hadoop-2.2 profile, and the profile was set to empty but kept for backwards compatibility reasons 660decc [FavioVazquez] - Updated Hadoop dependencies due to inconsistency in the versions. Now the global properties are the ones used by the hadoop-2.2 profile, and the profile was set to empty but kept for backwards compatibility reasons ec91ce3 [FavioVazquez] - Updated protobuf-java version of com.google.protobuf dependancy to fix blocking error when connecting to HDFS via the Hadoop Cloudera HDFS CDH5 (fix for 2.5.0-cdh5.3.3 version) --- dev/create-release/create-release.sh | 14 ++-- dev/run-tests | 6 +- dev/scalastyle | 4 +- docs/building-spark.md | 11 +-- docs/hadoop-third-party-distributions.md | 2 +- make-distribution.sh | 2 +- pom.xml | 33 ++++---- yarn/pom.xml | 97 +++++++++++------------- 8 files changed, 79 insertions(+), 90 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 3dbb35f7054a2..af4f00054997c 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,14 +118,14 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + build/mvn -DskipTests -Pyarn -Phive \ + -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh - build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + build/mvn -DskipTests -Pyarn -Phive \ + -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.10.sh @@ -228,9 +228,9 @@ if [[ ! "$@" =~ --skip-package ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. 
- make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" "3030" & - make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop1" "-Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" "3035" & diff --git a/dev/run-tests b/dev/run-tests index ef587a1a5988c..44d802782c4a4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -40,11 +40,11 @@ function handle_error () { { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4" + export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=1.0.4" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" + export SBT_MAVEN_PROFILES_ARGS="-Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi diff --git a/dev/scalastyle b/dev/scalastyle index 4e03f89ed5d5d..7f014c82f14c6 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -20,8 +20,8 @@ echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt echo -e "q\n" | build/sbt -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 test:scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt +echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt diff --git a/docs/building-spark.md b/docs/building-spark.md index 287fcd3c4034f..6e310ff424784 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -59,14 +59,14 @@ You can fix this by setting the `MAVEN_OPTS` variable as discussed before. # Specifying the Hadoop Version -Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 1.0.4 by default. Note that certain build profiles are required for particular Hadoop versions: +Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 2.2.0 by default. 
Note that certain build profiles are required for particular Hadoop versions: - + @@ -77,10 +77,10 @@ For Apache Hadoop versions 1.x, Cloudera CDH "mr1" distributions, and other Hado {% highlight bash %} # Apache Hadoop 1.2.1 -mvn -Dhadoop.version=1.2.1 -DskipTests clean package +mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package # Cloudera CDH 4.2.0 with MapReduce v1 -mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -DskipTests clean package +mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package {% endhighlight %} You can enable the "yarn" profile and optionally set the "yarn.version" property if it is different from "hadoop.version". Spark only supports YARN versions 2.2.0 and later. @@ -88,8 +88,9 @@ You can enable the "yarn" profile and optionally set the "yarn.version" property Examples: {% highlight bash %} + # Apache Hadoop 2.2.X -mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package +mvn -Pyarn -Phadoop-2.2 -DskipTests clean package # Apache Hadoop 2.3.X mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index 96bd69ca3b33b..795dd82a6be06 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -14,7 +14,7 @@ property. For certain versions, you will need to specify additional profiles. Fo see the guide on [building with maven](building-spark.html#specifying-the-hadoop-version): mvn -Dhadoop.version=1.0.4 -DskipTests clean package - mvn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package + mvn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package The table below lists the corresponding `hadoop.version` code for each CDH/HDP release. Note that some Hadoop releases are binary compatible across client versions. This means the pre-built Spark diff --git a/make-distribution.sh b/make-distribution.sh index 1bfa9acb1fe6e..8d6e91d67593f 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." 
exit_with_usage ;; --with-yarn) diff --git a/pom.xml b/pom.xml index 564a443466e5a..91d1d843c762a 100644 --- a/pom.xml +++ b/pom.xml @@ -122,9 +122,9 @@ 1.7.101.2.172.2.0 - 2.4.1 + 2.5.0${hadoop.version} - 0.98.7-hadoop1 + 0.98.7-hadoop2hbase1.4.03.4.5 @@ -143,7 +143,7 @@ 2.0.83.1.01.7.7 - + hadoop20.7.11.8.31.1.0 @@ -155,7 +155,7 @@ ${scala.version}org.scala-lang3.6.3 - 1.8.8 + 1.9.132.4.41.1.1.71.1.2 @@ -1644,26 +1644,27 @@ --> - hadoop-2.2 + hadoop-1 - 2.2.0 - 2.5.0 - 0.98.7-hadoop2 - hadoop2 - 1.9.13 + 1.0.4 + 2.4.1 + 0.98.7-hadoop1 + hadoop1 + 1.8.8 + + hadoop-2.2 + + + hadoop-2.3 2.3.0 - 2.5.0 0.9.3 - 0.98.7-hadoop2 3.1.1 - hadoop2 - 1.9.13 @@ -1671,12 +1672,8 @@ hadoop-2.4 2.4.0 - 2.5.0 0.9.3 - 0.98.7-hadoop2 3.1.1 - hadoop2 - 1.9.13 diff --git a/yarn/pom.xml b/yarn/pom.xml index 7c8c3613e7a05..00d219f836708 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -30,6 +30,7 @@ Spark Project YARN yarn + 1.9 @@ -85,7 +86,12 @@ jetty-servlet - + + + org.apache.hadoop hadoop-yarn-server-tests @@ -97,59 +103,44 @@ mockito-all test + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + - - - - - hadoop-2.2 - - 1.9 - - - - org.mortbay.jetty - jetty - 6.1.26 - - - org.mortbay.jetty - servlet-api - - - test - - - com.sun.jersey - jersey-core - ${jersey.version} - test - - - com.sun.jersey - jersey-json - ${jersey.version} - test - - - stax - stax-api - - - - - com.sun.jersey - jersey-server - ${jersey.version} - test - - - - - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes From f2cd00be350fdba3acfbfdf155701182d1c404fd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 May 2015 10:25:18 -0700 Subject: [PATCH 038/109] [SQL][minor] rename apply for QueryPlanner A follow-up of https://github.com/apache/spark/pull/5624 Author: Wenchen Fan Closes #6142 from cloud-fan/tmp and squashes the following commits: 971a92b [Wenchen Fan] use plan instead of execute 24c5ffe [Wenchen Fan] rename apply --- .../org/apache/spark/sql/catalyst/planning/QueryPlanner.scala | 4 ++-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 51b5699affed5..73a21884a4710 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -51,9 +51,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { * filled in automatically by the QueryPlanner using the other execution strategies that are * available. */ - protected def planLater(plan: LogicalPlan) = apply(plan).next() + protected def planLater(plan: LogicalPlan) = this.plan(plan).next() - def apply(plan: LogicalPlan): Iterator[PhysicalPlan] = { + def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // Obviously a lot to do here still... 
val iter = strategies.view.flatMap(_(plan)).toIterator assert(iter.hasNext, s"No plan for $plan") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 521f3dc821795..b33a700208014 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1321,7 +1321,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) - planner(optimizedPlan).next() + planner.plan(optimizedPlan).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. From 5d7d4f887d509e6d037d8fc5247d2e5f8a4563c9 Mon Sep 17 00:00:00 2001 From: ksonj Date: Thu, 14 May 2015 15:10:58 -0700 Subject: [PATCH 039/109] [SPARK-7278] [PySpark] DateType should find datetime.datetime acceptable DateType should not be restricted to `datetime.date` but accept `datetime.datetime` objects as well. Could someone with a little more insight verify this? Author: ksonj Closes #6057 from ksonj/dates and squashes the following commits: 68a158e [ksonj] DateType should find datetime.datetime acceptable too --- python/pyspark/sql/_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index b96851a174d49..629c3a94513b8 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -930,7 +930,7 @@ def _infer_schema_type(obj, dataType): DecimalType: (decimal.Decimal,), StringType: (str, unicode), BinaryType: (bytearray,), - DateType: (datetime.date,), + DateType: (datetime.date, datetime.datetime), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), From 11a1a135d1fe892cd48a9116acc7554846aed84c Mon Sep 17 00:00:00 2001 From: tedyu Date: Thu, 14 May 2015 15:26:35 -0700 Subject: [PATCH 040/109] Make SPARK prefix a variable Author: tedyu Closes #6153 from ted-yu/master and squashes the following commits: 4e0bac5 [tedyu] Use JIRA_PROJECT_NAME as variable name ab982aa [tedyu] Make SPARK prefix a variable --- dev/github_jira_sync.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index 8051080117062..ff1e39664ee04 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -33,6 +33,7 @@ # User facing configs GITHUB_API_BASE = os.environ.get("GITHUB_API_BASE", "https://api.github.com/repos/apache/spark") +JIRA_PROJECT_NAME = os.environ.get("JIRA_PROJECT_NAME", "SPARK") JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "apachespark") JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "XXX") @@ -68,7 +69,7 @@ def get_jira_prs(): page_json = get_json(page) for pull in page_json: - jiras = re.findall("SPARK-[0-9]{4,5}", pull['title']) + jiras = re.findall(JIRA_PROJECT_NAME + "-[0-9]{4,5}", pull['title']) for jira in jiras: result = result + [(jira, pull)] From 93dbb3ad83fd60444a38c3dc87a2053c667123af Mon Sep 17 00:00:00 2001 From: Rex Xiong Date: Thu, 14 May 2015 16:55:31 -0700 Subject: [PATCH 041/109] [SPARK-7598] [DEPLOY] Add aliveWorkers metrics in Master In Spark Standalone setup, when some workers are DEAD, they will stay in master worker list for a while. 
master.workers metrics for master is only showing the total number of workers, we need to monitor how many real ALIVE workers are there to ensure the cluster is healthy. Author: Rex Xiong Closes #6117 from twilightgod/add-aliveWorker-metrics and squashes the following commits: 6be69a5 [Rex Xiong] Fix comment for aliveWorkers metrics a882f39 [Rex Xiong] Fix style for aliveWorkers metrics 38ce955 [Rex Xiong] Add aliveWorkers metrics in Master --- .../scala/org/apache/spark/deploy/master/MasterSource.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 9c3f79f1244b7..66a9ff38678c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -30,6 +30,11 @@ private[spark] class MasterSource(val master: Master) extends Source { override def getValue: Int = master.workers.size }) + // Gauge for alive worker numbers in cluster + metricRegistry.register(MetricRegistry.name("aliveWorkers"), new Gauge[Int]{ + override def getValue: Int = master.workers.filter(_.state == WorkerState.ALIVE).size + }) + // Gauge for application numbers in cluster metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] { override def getValue: Int = master.apps.size From 57ed16cf9372c109e84bd51b728f2c82940949a7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 14 May 2015 16:56:32 -0700 Subject: [PATCH 042/109] [SPARK-7643] [UI] use the correct size in RDDPage for storage info and partitions `dataDistribution` and `partitions` are `Option[Seq[_]]`. andrewor14 squito Author: Xiangrui Meng Closes #6157 from mengxr/SPARK-7643 and squashes the following commits: 99fe8a4 [Xiangrui Meng] use the correct size in RDDPage for storage info and partitions --- .../main/scala/org/apache/spark/ui/storage/RDDPage.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 05f94a7507f4f..fbce917a0824d 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -77,14 +77,17 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
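// A brief illustrative sketch, not part of the patch below: `rddStorageInfo.dataDistribution`
// and `rddStorageInfo.partitions` are Option[Seq[_]], so calling `.size` on them counts the
// Option itself (0 or 1) rather than its elements; mapping over the Option first gives the
// intended count. The sample values here are made up purely for illustration.
val dataDistribution: Option[Seq[Long]] = Some(Seq(512L, 1024L, 2048L))
dataDistribution.size                      // 1 -- size of the Option, not of the Seq
dataDistribution.map(_.size).getOrElse(0)  // 3 -- the element count the page should show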
-            Data Distribution on {rddStorageInfo.dataDistribution.size} Executors
+            Data Distribution on {rddStorageInfo.dataDistribution.map(_.size).getOrElse(0)}
+            Executors
             {workerTable}
-            {rddStorageInfo.partitions.size} Partitions
+            {rddStorageInfo.partitions.map(_.size).getOrElse(0)} Partitions
; From 0a317c124c3a43089cdb8f079345c8f2842238cd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 14 May 2015 16:57:33 -0700 Subject: [PATCH 043/109] [SPARK-7649] [STREAMING] [WEBUI] Use window.localStorage to store the status rather than the url Use window.localStorage to store the status rather than the url so that the url won't be changed. cc tdas Author: zsxwing Closes #6158 from zsxwing/SPARK-7649 and squashes the following commits: 3c56fef [zsxwing] Use window.localStorage to store the status rather than the url --- .../apache/spark/ui/static/streaming-page.js | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js index 22b186873e990..0fac658d57842 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js +++ b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js @@ -252,28 +252,16 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { } $(function() { - function getParameterFromURL(param) - { - var parameters = window.location.search.substring(1); // Remove "?" - var keyValues = parameters.split('&'); - for (var i = 0; i < keyValues.length; i++) - { - var paramKeyValue = keyValues[i].split('='); - if (paramKeyValue[0] == param) - { - return paramKeyValue[1]; - } - } - } - - var status = getParameterFromURL("show-streams-detail") == "true"; + var status = window.localStorage && window.localStorage.getItem("show-streams-detail") == "true"; $("span.expand-input-rate").click(function() { status = !status; $("#inputs-table").toggle('collapsed'); // Toggle the class of the arrow between open and closed $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); - window.history.pushState('', document.title, window.location.pathname + '?show-streams-detail=' + status); + if (window.localStorage) { + window.localStorage.setItem("show-streams-detail", "" + status); + } }); if (status) { From b208f998b5800bdba4ce6651f172c26a8d7d351b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 14 May 2015 16:58:36 -0700 Subject: [PATCH 044/109] [SPARK-7645] [STREAMING] [WEBUI] Show milliseconds in the UI if the batch interval < 1 second I also updated the summary of the Streaming page. 
![screen shot 2015-05-14 at 11 52 59 am](https://cloud.githubusercontent.com/assets/1000778/7640103/13cdf68e-fa36-11e4-84ec-e2a3954f4319.png) ![screen shot 2015-05-14 at 12 39 33 pm](https://cloud.githubusercontent.com/assets/1000778/7640151/4cc066ac-fa36-11e4-8494-2821d6a6f17c.png) Author: zsxwing Closes #6154 from zsxwing/SPARK-7645 and squashes the following commits: 5db6ca1 [zsxwing] Add UIUtils.formatBatchTime e4802df [zsxwing] Show milliseconds in the UI if the batch interval < 1 second --- .../apache/spark/ui/static/streaming-page.js | 11 +++- .../spark/streaming/ui/AllBatchesTable.scala | 14 +++-- .../apache/spark/streaming/ui/BatchPage.scala | 5 +- .../spark/streaming/ui/StreamingPage.scala | 10 ++-- .../apache/spark/streaming/ui/UIUtils.scala | 55 ++++++++++++++++++- .../spark/streaming/ui/UIUtilsSuite.scala | 11 ++++ 6 files changed, 94 insertions(+), 12 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js index 0fac658d57842..0ee6752b29e9a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js +++ b/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js @@ -98,7 +98,16 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { var x = d3.scale.linear().domain([minX, maxX]).range([0, width]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); - var xAxis = d3.svg.axis().scale(x).orient("bottom").tickFormat(function(d) { return timeFormat[d]; }); + var xAxis = d3.svg.axis().scale(x).orient("bottom").tickFormat(function(d) { + var formattedDate = timeFormat[d]; + var dotIndex = formattedDate.indexOf('.'); + if (dotIndex >= 0) { + // Remove milliseconds + return formattedDate.substring(0, dotIndex); + } else { + return formattedDate; + } + }); var formatYValue = d3.format(",.2f"); var yAxis = d3.svg.axis().scale(y).orient("left").ticks(5).tickFormat(formatYValue); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 3619e129ad9cf..00cc47d6a3ca5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -17,11 +17,14 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.Date + import scala.xml.Node import org.apache.spark.ui.{UIUtils => SparkUIUtils} -private[ui] abstract class BatchTableBase(tableId: String) { +private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) { protected def columns: Seq[Node] = {
@@ -35,7 +38,7 @@ private[ui] abstract class BatchTableBase(tableId: String) { protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds - val formattedBatchTime = SparkUIUtils.formatDate(batch.batchTime.milliseconds) + val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) val eventCount = batch.numRecords val schedulingDelay = batch.schedulingDelay val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") @@ -79,7 +82,8 @@ private[ui] abstract class BatchTableBase(tableId: String) { private[ui] class ActiveBatchTable( runningBatches: Seq[BatchUIData], - waitingBatches: Seq[BatchUIData]) extends BatchTableBase("active-batches-table") { + waitingBatches: Seq[BatchUIData], + batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { override protected def columns: Seq[Node] = super.columns ++ @@ -99,8 +103,8 @@ private[ui] class ActiveBatchTable( } } -private[ui] class CompletedBatchTable(batches: Seq[BatchUIData]) - extends BatchTableBase("completed-batches-table") { +private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) + extends BatchTableBase("completed-batches-table", batchInterval) { override protected def columns: Seq[Node] = super.columns ++ - val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, + val accumulableTable = UIUtils.listingTable( + accumulableHeaders, + accumulableRow, accumulables.values.toSeq) val taskHeadersAndCssClasses: Seq[(String, String)] = @@ -232,10 +264,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val unzipped = taskHeadersAndCssClasses.unzip + val currentTime = System.currentTimeMillis() val taskTable = UIUtils.listingTable( unzipped._1, - taskRow(hasAccumulators, stageData.hasInput, stageData.hasOutput, - stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled), + taskRow( + hasAccumulators, + stageData.hasInput, + stageData.hasOutput, + stageData.hasShuffleRead, + stageData.hasShuffleWrite, + stageData.hasBytesSpilled, + currentTime), tasks, headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics @@ -460,25 +499,192 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { dagViz ++ maybeExpandDagViz ++ showAdditionalMetrics ++ + makeTimeline(stageData.taskData.values.toSeq, currentTime) ++

        Summary Metrics for {numCompleted} Completed Tasks ++
        {summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
        Aggregated Metrics by Executor ++ executorTable.toNodeSeq ++ maybeAccumulableTable ++
        Tasks

++ taskTable - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } + def makeTimeline(tasks: Seq[TaskUIData], currentTime: Long): Seq[Node] = { + val executorsSet = new HashSet[(String, String)] + var minLaunchTime = Long.MaxValue + var maxFinishTime = Long.MinValue + + val executorsArrayStr = + tasks.sortBy(-_.taskInfo.launchTime).take(MAX_TIMELINE_TASKS).map { taskUIData => + val taskInfo = taskUIData.taskInfo + val executorId = taskInfo.executorId + val host = taskInfo.host + executorsSet += ((executorId, host)) + + val classNameByStatus = { + if (taskInfo.successful) { + "succeeded" + } else if (taskInfo.failed) { + "failed" + } else if (taskInfo.running) { + "running" + } + } + + val launchTime = taskInfo.launchTime + val finishTime = if (!taskInfo.running) taskInfo.finishTime else currentTime + val totalExecutionTime = finishTime - launchTime + minLaunchTime = launchTime.min(minLaunchTime) + maxFinishTime = launchTime.max(maxFinishTime) + + def toProportion(time: Long) = (time.toDouble / totalExecutionTime * 100).toLong + + val metricsOpt = taskUIData.taskMetrics + val shuffleReadTime = + metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L) + val shuffleReadTimeProportion = toProportion(shuffleReadTime) + val shuffleWriteTime = + (metricsOpt.flatMap(_.shuffleWriteMetrics + .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong + val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) + val executorComputingTime = metricsOpt.map(_.executorRunTime).getOrElse(0L) - + shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = toProportion(executorComputingTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) + val serializationTimeProportion = toProportion(serializationTime) + val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) + val deserializationTimeProportion = toProportion(deserializationTime) + val gettingResultTime = getGettingResultTime(taskUIData.taskInfo) + val gettingResultTimeProportion = toProportion(gettingResultTime) + val schedulerDelay = totalExecutionTime - + (executorComputingTime + shuffleReadTime + shuffleWriteTime + + serializationTime + deserializationTime + gettingResultTime) + val schedulerDelayProportion = + (100 - executorComputingTimeProportion - shuffleReadTimeProportion - + shuffleWriteTimeProportion - serializationTimeProportion - + deserializationTimeProportion - gettingResultTimeProportion) + + val schedulerDelayProportionPos = 0 + val deserializationTimeProportionPos = + schedulerDelayProportionPos + schedulerDelayProportion + val shuffleReadTimeProportionPos = + deserializationTimeProportionPos + deserializationTimeProportion + val executorRuntimeProportionPos = + shuffleReadTimeProportionPos + shuffleReadTimeProportion + val shuffleWriteTimeProportionPos = + executorRuntimeProportionPos + executorComputingTimeProportion + val serializationTimeProportionPos = + shuffleWriteTimeProportionPos + shuffleWriteTimeProportion + val gettingResultTimeProportionPos = + serializationTimeProportionPos + serializationTimeProportion + + val index = taskInfo.index + val attempt = taskInfo.attempt + val timelineObject = + s""" + { + 'className': 'task task-assignment-timeline-object $classNameByStatus', + 'group': '$executorId', + 'content': '
' +
+              'Status: ${taskInfo.status}' +
+              'Launch Time: ${UIUtils.formatDate(new Date(launchTime))}' +
+              '${
+                if (!taskInfo.running) {
+                  s"""Finish Time: ${UIUtils.formatDate(new Date(finishTime))}"""
+                } else {
+                  ""
+                }
+              }' +
+              'Scheduler Delay: $schedulerDelay ms' +
+              'Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)}' +
+              'Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)}' +
+              'Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)}' +
+              'Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)}' +
+              'Result Serialization Time: ${UIUtils.formatDuration(serializationTime)}' +
+              '
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}">' + + '' + + '' + + '' + + '' + + '' + + '' + + '' + + '', + 'start': new Date($launchTime), + 'end': new Date($finishTime) + } + """ + timelineObject + }.mkString("[", ",", "]") + + val groupArrayStr = executorsSet.map { + case (executorId, host) => + s""" + { + 'id': '$executorId', + 'content': '$executorId / $host', + } + """ + }.mkString("[", ",", "]") + + val maxZoom = maxFinishTime - minLaunchTime + + + Event Timeline + ++ + ++ + + } + def taskRow( hasAccumulators: Boolean, hasInput: Boolean, hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = { + hasBytesSpilled: Boolean, + currentTime: Long)(taskData: TaskUIData): Seq[Node] = { taskData match { case TaskUIData(info, metrics, errorMessage) => - val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) + val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") @@ -542,7 +748,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") -
+ +
       Hadoop version      Profile required
-      1.x to 2.1.x        (none)
+      1.x to 2.1.x        hadoop-1
       2.2.x               hadoop-2.2
       2.3.x               hadoop-2.3
       2.4.x               hadoop-2.4
Batch TimeStatusTotal Delay diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 831f60e870f74..f75067669abe5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.Date import javax.servlet.http.HttpServletRequest import scala.xml.{NodeSeq, Node, Text} @@ -288,7 +290,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } - val formattedBatchTime = SparkUIUtils.formatDate(batchTime.milliseconds) + val formattedBatchTime = + UIUtils.formatBatchTime(batchTime.milliseconds, streamingListener.batchDuration) val batchUIData = streamingListener.getBatchUIData(batchTime).getOrElse { throw new IllegalArgumentException(s"Batch $formattedBatchTime does not exist") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index efce8c58fb962..070564aa10633 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -186,6 +186,8 @@ private[ui] class StreamingPage(parent: StreamingTab) {SparkUIUtils.formatDate(startTime)} + ({listener.numTotalCompletedBatches} + completed batches, {listener.numTotalReceivedRecords} records)
} @@ -199,9 +201,9 @@ private[ui] class StreamingPage(parent: StreamingTab) * @param times all time values that will be used in the graphs. */ private def generateTimeMap(times: Seq[Long]): Seq[Node] = { - val dateFormat = new SimpleDateFormat("HH:mm:ss") val js = "var timeFormat = {};\n" + times.map { time => - val formattedTime = dateFormat.format(new Date(time)) + val formattedTime = + UIUtils.formatBatchTime(time, listener.batchDuration, showYYYYMMSS = false) s"timeFormat[$time] = '$formattedTime';" }.mkString("\n") @@ -472,14 +474,14 @@ private[ui] class StreamingPage(parent: StreamingTab) val activeBatchesContent = {

      Active Batches ({runningBatches.size + waitingBatches.size}) ++
-      new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq
+      new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq
     }

     val completedBatchesContent = {
      Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches})

++ - new CompletedBatchTable(completedBatches).toNodeSeq + new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq } activeBatchesContent ++ completedBatchesContent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index f153ee105a18e..86cfb1fa47370 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.TimeZone import java.util.concurrent.TimeUnit private[streaming] object UIUtils { @@ -62,7 +64,7 @@ private[streaming] object UIUtils { * Convert `milliseconds` to the specified `unit`. We cannot use `TimeUnit.convert` because it * will discard the fractional part. */ - def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { + def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { case TimeUnit.NANOSECONDS => milliseconds * 1000 * 1000 case TimeUnit.MICROSECONDS => milliseconds * 1000 case TimeUnit.MILLISECONDS => milliseconds @@ -71,4 +73,55 @@ private[streaming] object UIUtils { case TimeUnit.HOURS => milliseconds / 1000.0 / 60.0 / 60.0 case TimeUnit.DAYS => milliseconds / 1000.0 / 60.0 / 60.0 / 24.0 } + + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. + private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + } + + private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + } + + /** + * If `batchInterval` is less than 1 second, format `batchTime` with milliseconds. Otherwise, + * format `batchTime` without milliseconds. + * + * @param batchTime the batch time to be formatted + * @param batchInterval the batch interval + * @param showYYYYMMSS if showing the `yyyy/MM/dd` part. 
If it's false, the return value wll be + * only `HH:mm:ss` or `HH:mm:ss.SSS` depending on `batchInterval` + * @param timezone only for test + */ + def formatBatchTime( + batchTime: Long, + batchInterval: Long, + showYYYYMMSS: Boolean = true, + timezone: TimeZone = null): String = { + val oldTimezones = + (batchTimeFormat.get.getTimeZone, batchTimeFormatWithMilliseconds.get.getTimeZone) + if (timezone != null) { + batchTimeFormat.get.setTimeZone(timezone) + batchTimeFormatWithMilliseconds.get.setTimeZone(timezone) + } + try { + val formattedBatchTime = + if (batchInterval < 1000) { + batchTimeFormatWithMilliseconds.get.format(batchTime) + } else { + // If batchInterval >= 1 second, don't show milliseconds + batchTimeFormat.get.format(batchTime) + } + if (showYYYYMMSS) { + formattedBatchTime + } else { + formattedBatchTime.substring(formattedBatchTime.indexOf(' ') + 1) + } + } finally { + if (timezone != null) { + batchTimeFormat.get.setTimeZone(oldTimezones._1) + batchTimeFormatWithMilliseconds.get.setTimeZone(oldTimezones._2) + } + } + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index 6df1a63ab2e37..e9ab917ab845c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.ui +import java.util.TimeZone import java.util.concurrent.TimeUnit import org.scalatest.FunSuite @@ -64,4 +65,14 @@ class UIUtilsSuite extends FunSuite with Matchers{ val convertedTime = UIUtils.convertToTimeUnit(milliseconds, unit) convertedTime should be (expectedTime +- 1E-6) } + + test("formatBatchTime") { + val tzForTest = TimeZone.getTimeZone("America/Los_Angeles") + val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015 + assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = tzForTest)) + assert("2015/05/14 14:04:40.452" === + UIUtils.formatBatchTime(batchTime, 999, timezone = tzForTest)) + assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = tzForTest)) + assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = tzForTest)) + } } From 723853edab18d28515af22097b76e4e6574b228e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 14 May 2015 18:13:58 -0700 Subject: [PATCH 045/109] [SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml Otherwise, users can only use `transform` on the models. 
brkyvz Author: Xiangrui Meng Closes #6156 from mengxr/SPARK-7647 and squashes the following commits: 1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel --- python/pyspark/ml/classification.py | 18 ++++++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 +++++++- 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 96d29058a3781..8c9a55e79abad 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeClassifierParams(object): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c3d20c3..2803864ff4a17 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a398642a..dda6c6aba3049 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -149,6 +149,12 @@ def __init__(self, java_model): def _java_obj(self): return self._java_model + def _call_java(self, name, *args): + m = getattr(self._java_model, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): From 48fc38f5844f6c12bf440f2990b6d7f1630fafac Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 14 May 2015 18:16:22 -0700 Subject: [PATCH 046/109] [SPARK-7619] [PYTHON] fix docstring signature Just realized that we need `\` at the end of the docstring. 
brkyvz Author: Xiangrui Meng Closes #6161 from mengxr/SPARK-7619 and squashes the following commits: e44495f [Xiangrui Meng] fix docstring signature --- python/docs/pyspark.ml.rst | 14 +++++------ python/pyspark/ml/classification.py | 39 ++++++++++++++--------------- python/pyspark/ml/feature.py | 8 +++--- python/pyspark/ml/recommendation.py | 8 +++--- python/pyspark/ml/regression.py | 38 +++++++++++++--------------- 5 files changed, 52 insertions(+), 55 deletions(-) diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 8379b8fc8a1e1..518b8e774dd5f 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -1,8 +1,8 @@ pyspark.ml package -===================== +================== ML Pipeline APIs --------------- +---------------- .. automodule:: pyspark.ml :members: @@ -10,7 +10,7 @@ ML Pipeline APIs :inherited-members: pyspark.ml.param module -------------------------- +----------------------- .. automodule:: pyspark.ml.param :members: @@ -34,7 +34,7 @@ pyspark.ml.classification module :inherited-members: pyspark.ml.recommendation module -------------------------- +-------------------------------- .. automodule:: pyspark.ml.recommendation :members: @@ -42,7 +42,7 @@ pyspark.ml.recommendation module :inherited-members: pyspark.ml.regression module -------------------------- +---------------------------- .. automodule:: pyspark.ml.regression :members: @@ -50,7 +50,7 @@ pyspark.ml.regression module :inherited-members: pyspark.ml.tuning module --------------------------------- +------------------------ .. automodule:: pyspark.ml.tuning :members: @@ -58,7 +58,7 @@ pyspark.ml.tuning module :inherited-members: pyspark.ml.evaluation module --------------------------------- +---------------------------- .. automodule:: pyspark.ml.evaluation :members: diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8c9a55e79abad..1411d3fd9c56e 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -71,7 +71,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred threshold=0.5, probabilityCol="probability"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, probabilityCol="probability") """ super(LogisticRegression, self).__init__() @@ -96,8 +96,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, probabilityCol="probability"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, probabilityCol="probability") Sets params for logistic regression. 
""" @@ -220,7 +220,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") """ super(DecisionTreeClassifier, self).__init__() @@ -242,9 +242,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="gini"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="gini") + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") Sets params for the DecisionTreeClassifier. """ kwargs = self.setParams._input_kwargs @@ -320,9 +319,9 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", seed=42): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ numTrees=20, featureSubsetStrategy="auto", seed=42) """ super(RandomForestClassifier, self).__init__() @@ -355,9 +354,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, impurity="gini", numTrees=20, featureSubsetStrategy="auto"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto") Sets params for linear classification. 
""" @@ -471,10 +470,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1) + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="logistic", maxIter=20, stepSize=0.1) """ super(GBTClassifier, self).__init__() #: param for Loss function which GBT tries to minimize (case-insensitive). @@ -502,9 +501,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ lossType="logistic", maxIter=20, stepSize=0.1) Sets params for Gradient Boosted Tree Classification. """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 30e1fd4922d0a..58e22190c7c3c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -481,7 +481,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", inputCol=None, outputCol=None): """ - __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", + __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \ inputCol=None, outputCol=None) """ super(RegexTokenizer, self).__init__() @@ -496,7 +496,7 @@ def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+" def setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", inputCol=None, outputCol=None): """ - setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", + setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \ inputCol="input", outputCol="output") Sets params for this RegexTokenizer. 
""" @@ -869,7 +869,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, inputCol=None, outputCol=None): """ - __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, + __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \ seed=42, inputCol=None, outputCol=None) """ super(Word2Vec, self).__init__() @@ -889,7 +889,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, inputCol=None, outputCol=None): """ - setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, + setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, \ inputCol=None, outputCol=None) Sets params for this Word2Vec. """ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 4846b907e85ec..b2439cbd96522 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -92,8 +92,8 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ - __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0, + __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0, \ ratingCol="rating", nonnegative=false, checkpointInterval=10) """ super(ALS, self).__init__() @@ -118,8 +118,8 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ - setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, \ ratingCol="rating", nonnegative=False, checkpointInterval=10) Sets params for ALS. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2803864ff4a17..ef77e19327188 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -33,8 +33,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: - L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ This support multiple types of regularization: - none (a.k.a. 
ordinary least squares) @@ -191,7 +190,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") """ super(DecisionTreeRegressor, self).__init__() @@ -213,9 +212,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance") + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs @@ -286,10 +284,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", numTrees=20, featureSubsetStrategy="auto", seed=42): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=42) + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", numTrees=20, featureSubsetStrategy="auto", seed=42) """ super(RandomForestRegressor, self).__init__() #: param for Criterion used for information gain calculation (case-insensitive). @@ -321,9 +319,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, impurity="variance", numTrees=20, featureSubsetStrategy="auto"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \ impurity="variance", numTrees=20, featureSubsetStrategy="auto") Sets params for linear regression. 
""" @@ -432,10 +430,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", - maxIter=20, stepSize=0.1) + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="squared", maxIter=20, stepSize=0.1) """ super(GBTRegressor, self).__init__() #: param for Loss function which GBT tries to minimize (case-insensitive). @@ -463,9 +461,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ lossType="squared", maxIter=20, stepSize=0.1) Sets params for Gradient Boosted Tree Regression. """ From 6d0633e3ec9518278fcc7eba58549d4ad3d5813f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 14 May 2015 19:49:44 -0700 Subject: [PATCH 047/109] [SPARK-7548] [SQL] Add explode function for DataFrames Add an `explode` function for dataframes and modify the analyzer so that single table generating functions can be present in a select clause along with other expressions. There are currently the following restrictions: - only top level TGFs are allowed (i.e. no `select(explode('list) + 1)`) - only one may be present in a single select to avoid potentially confusing implicit Cartesian products. TODO: - [ ] Python Author: Michael Armbrust Closes #6107 from marmbrus/explodeFunction and squashes the following commits: 7ee2c87 [Michael Armbrust] whitespace 6f80ba3 [Michael Armbrust] Update dataframe.py c176c89 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction 81b5da3 [Michael Armbrust] style d3faa05 [Michael Armbrust] fix self join case f9e1e3e [Michael Armbrust] fix python, add since 4f0d0a9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction e710fe4 [Michael Armbrust] add java and python 52ca0dc [Michael Armbrust] [SPARK-7548][SQL] Add explode function for dataframes. 
--- python/pyspark/sql/dataframe.py | 12 +- python/pyspark/sql/functions.py | 20 +++ python/pyspark/sql/tests.py | 15 +++ .../sql/catalyst/analysis/Analyzer.scala | 117 +++++++++++------- .../plans/logical/basicOperators.scala | 3 + .../sql/catalyst/analysis/AnalysisSuite.scala | 10 +- .../scala/org/apache/spark/sql/Column.scala | 27 +++- .../org/apache/spark/sql/DataFrame.scala | 5 +- .../org/apache/spark/sql/functions.scala | 5 + .../spark/sql/ColumnExpressionSuite.scala | 60 +++++++++ 10 files changed, 223 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 82cb1c2fdbf94..2ed95ac8e2505 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1511,13 +1511,19 @@ def inSet(self, *cols): isNull = _unary_op("isNull", "True if the current expression is null.") isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - def alias(self, alias): - """Return a alias for this column + def alias(self, *alias): + """Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ - return Column(getattr(self._jc, "as")(alias)) + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) @ignore_unicode_prefix def cast(self, dataType): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d91265ee0bec8..6cd6974b0e5bb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + def coalesce(*cols): """Returns the first column that is not null. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1922d03af61da..d37c5dbed7f6b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -117,6 +117,21 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4baeeb5b58c2d..0b6e1d44b9c4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -73,7 +73,6 @@ class Analyzer( ResolveGroupingAnalytics :: ResolveSortReferences :: ResolveGenerate :: - ImplicitGenerate :: ResolveFunctions :: ExtractWindowExpressions :: GlobalAggregates :: @@ -323,6 +322,11 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + case oldVersion @ Window(_, windowExpressions, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => @@ -521,66 +525,89 @@ class Analyzer( } /** - * When a SELECT clause has only a single expression and that expression is a - * [[catalyst.expressions.Generator Generator]] we convert the - * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]]. + * Rewrites table generating expressions that either need one or more of the following in order + * to be resolved: + * - concrete attribute references for their output. + * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * + * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions + * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an + * [[AnalysisException]] is throw. 
*/ - object ImplicitGenerate extends Rule[LogicalPlan] { + object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, name)), child) => - Generate(g, join = false, outer = false, - qualifier = None, UnresolvedAttribute(name) :: Nil, child) - case Project(Seq(MultiAlias(g: Generator, names)), child) => - Generate(g, join = false, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), child) + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case g: Generate if g.resolved == false => + g.copy( + generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + + case p @ Project(projectList, child) => + // Holds the resolved generator, if one exists in the project list. + var resolvedGenerator: Generate = null + + val newProjectList = projectList.flatMap { + case AliasedGenerator(generator, names) if generator.childrenResolved => + if (resolvedGenerator != null) { + failAnalysis( + s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " + + s"and ${generator.nodeName} found.") + } + + resolvedGenerator = + Generate( + generator, + join = projectList.size > 1, // Only join if there are other expressions in SELECT. + outer = false, + qualifier = None, + generatorOutput = makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } + + if (resolvedGenerator != null) { + Project(newProjectList, resolvedGenerator) + } else { + p + } } - } - /** - * Resolve the Generate, if the output names specified, we will take them, otherwise - * we will try to provide the default names, which follow the same rule with Hive. - */ - object ResolveGenerate extends Rule[LogicalPlan] { - // Construct the output attributes for the generator, - // The output attribute names can be either specified or - // auto generated. + /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ + private object AliasedGenerator { + def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { + case Alias(g: Generator, name) => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) => Some(g, names) + case _ => None + } + } + + /** + * Construct the output attributes for a [[Generator]], given a list of names. If the list of + * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults. 
+ */ private def makeGeneratorOutput( generator: Generator, - generatorOutput: Seq[Attribute]): Seq[Attribute] = { + names: Seq[String]): Seq[Attribute] = { val elementTypes = generator.elementTypes - if (generatorOutput.length == elementTypes.length) { - generatorOutput.zip(elementTypes).map { - case (a, (t, nullable)) if !a.resolved => - AttributeReference(a.name, t, nullable)() - case (a, _) => a + if (names.length == elementTypes.length) { + names.zip(elementTypes).map { + case (name, (t, nullable)) => + AttributeReference(name, t, nullable)() } - } else if (generatorOutput.length == 0) { + } else if (names.isEmpty) { elementTypes.zipWithIndex.map { // keep the default column names as Hive does _c0, _c1, _cN case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() } } else { - throw new AnalysisException( - s""" - |The number of aliases supplied in the AS clause does not match - |the number of columns output by the UDTF expected - |${elementTypes.size} aliases but got ${generatorOutput.size} - """.stripMargin) + failAnalysis( + "The number of aliases supplied in the AS clause does not match the number of columns " + + s"output by the UDTF expected ${elementTypes.size} aliases but got " + + s"${names.mkString(",")} ") } } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Generate if !p.child.resolved || !p.generator.resolved => p - case p: Generate if p.resolved == false => - // if the generator output names are not specified, we will use the default ones. - Generate( - p.generator, - join = p.join, - outer = p.outer, - p.qualifier, - makeGeneratorOutput(p.generator, p.generatorOutput), p.child) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0f349f9d11415..01f4b6e9bb77d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -59,6 +59,9 @@ case class Generate( child: LogicalPlan) extends UnaryNode { + /** The set of all attributes produced by this node. 
*/ + def generatedSet: AttributeSet = AttributeSet(generatorOutput) + override lazy val resolved: Boolean = { generator.resolved && childrenResolved && diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f2f35564d12e..e1d6ac462fbcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { StructField("cField", StringType) :: Nil ))()) + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -159,10 +162,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } } - errorMessages.foreach(m => assert(error.getMessage contains m)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) } } + errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + errorTest( "unresolved attributes", testRelation.select('abcd), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8bf1320ccb71d..dc0aeea7c4aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql import scala.language.implicitConversions +import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} import org.apache.spark.sql.types._ @@ -727,6 +728,30 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def as(alias: String): Column = Alias(expr, alias)() + /** + * (Scala-specific) Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + + /** + * Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + /** * Gives the column an alias. 
* {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4fd5105c27443..2e20c3d3f4ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -593,6 +593,9 @@ class DataFrame private[sql]( def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { case Column(expr: NamedExpression) => expr + // Leave an unaliased explode with an empty list of names since the analzyer will generate the + // correct defaults after the nested expression's type has been resolved. + case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() } // When user continuously call `select`, speed up analysis by collapsing `Project` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4404ad8ad63a8..6640631cf0719 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -363,6 +363,11 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Creates a new row for each element in the given array or map column. + */ + def explode(e: Column): Column = Explode(e.expr) + /** * Converts a string exprsesion to lower case. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 269e185543059..9bdf201b3be7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -27,6 +27,66 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + test("single explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + checkAnswer( + df.select(explode('intList)), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + + test("explode and other columns") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + + checkAnswer( + df.select($"a", explode('intList)), + Row(1, 1) :: + Row(1, 2) :: + Row(1, 3) :: Nil) + + checkAnswer( + df.select($"*", explode('intList)), + Row(1, Seq(1,2,3), 1) :: + Row(1, Seq(1,2,3), 2) :: + Row(1, Seq(1,2,3), 3) :: Nil) + } + + test("aliased explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + + checkAnswer( + df.select(explode('intList).as('int)).select('int), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + df.select(explode('intList).as('int)).select(sum('int)), + Row(6) :: Nil) + } + + test("explode on map") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map)), + Row("a", "b")) + } + + test("explode on map with aliases") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), + Row("a", "b")) + } + + test("self join explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val exploded = df.select(explode('intList).as('i)) + + checkAnswer( + exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), + Row(3) :: Nil) + } + test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) From f9705d461350c6fccf8022e933ea909f40c53576 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 May 2015 20:49:21 -0700 Subject: [PATCH 048/109] [SPARK-7098][SQL] Make the WHERE clause with timestamp show consistent result JIRA: https://issues.apache.org/jira/browse/SPARK-7098 The WHERE clause with timstamp shows inconsistent results. This pr fixes it. Author: Liang-Chi Hsieh Closes #5682 from viirya/consistent_timestamp and squashes the following commits: 171445a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into consistent_timestamp 4e98520 [Liang-Chi Hsieh] Make the WHERE clause with timestamp show consistent result. 
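Roughly, the inconsistency being fixed is the following (sketch assumes a `sqlContext` in scope and the `timestamps` temporary table registered by `TestData` in this patch): a timestamp-vs-string comparison used to be evaluated as a string comparison, which could disagree with the explicitly cast form; with the string side now coerced to `TimestampType`, the two queries below return the same rows.

```scala
// Previously the first query compared the two sides as strings and could miss rows
// that the second query matched; both now perform a timestamp comparison.
sqlContext.sql(
  "SELECT time FROM timestamps WHERE time = '1969-12-31 16:00:00.0'").show()
sqlContext.sql(
  "SELECT time FROM timestamps WHERE time = CAST('1969-12-31 16:00:00.0' AS TIMESTAMP)").show()
```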
--- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 6 +++--- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++++ sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 168a4e30eab86..fe0d3f29977c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -251,10 +251,10 @@ trait HiveTypeCoercion { p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryComparison if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) + p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) + p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == DateType => p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) @@ -274,7 +274,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b)) case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(Cast(a, StringType), b)) + i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8cdbe076cbd85..479ad9fe621d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -297,6 +297,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3173 Timestamp support in the parser") { + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) + checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 446771ab2a5a5..8fbc2d23d47e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -175,7 +175,7 @@ object TestData { "4, D4, true, 2147483644" :: Nil) case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => TimestampField(new Timestamp(i)) }) timestamps.toDF().registerTempTable("timestamps") From e8f0e016eaf80a363796dd0a094291dcb3b35793 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 15 May 2015 12:04:26 +0800 Subject: [PATCH 049/109] [SQL] When creating partitioned table scan, explicitly create UnionRDD. Otherwise, it will cause stack overflow when there are many partitions. 
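To make the failure mode concrete, a small illustrative sketch (the helper name `unionAll` is made up; `sc` stands for an existing `SparkContext`): folding per-partition RDDs together with `++` nests one binary `UnionRDD` per element, so the RDD lineage grows as deep as the number of partitions, while a single n-ary `UnionRDD` keeps every input as a direct child.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.{RDD, UnionRDD}

def unionAll(sc: SparkContext, rdds: Seq[RDD[Int]]): RDD[Int] = {
  // rdds.reduce(_ ++ _) would build a lineage of depth rdds.length; recursing over
  // that chain (e.g. when resolving dependencies) can overflow the stack once a
  // table has thousands of partitions. A flat n-ary union avoids the deep nesting:
  new UnionRDD(sc, rdds)
}
```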
Author: Yin Huai Closes #6162 from yhuai/partitionUnionedRDD and squashes the following commits: fa016d8 [Yin Huai] Explicitly create UnionRDD. --- .../apache/spark/sql/sources/DataSourceStrategy.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index a5410cda0fe6b..ee099ab9593c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{UnionRDD, RDD} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -169,9 +169,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { scan.execute() } - val unionedRows = perPartitionRows.reduceOption(_ ++ _).getOrElse { - relation.sqlContext.emptyResult - } + val unionedRows = + if (perPartitionRows.length == 0) { + relation.sqlContext.emptyResult + } else { + new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + } createPhysicalRDD(logicalRelation.relation, output, unionedRows) } From 7da33ce5057ff965eec19ce662465b64a3564019 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 14 May 2015 23:17:41 -0700 Subject: [PATCH 050/109] [HOTFIX] Add workaround for SPARK-7660 to fix JavaAPISuite failures. --- .../spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 730d265c87f88..78e52643531e0 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -35,6 +35,7 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import org.xerial.snappy.buffer.CachedBufferAllocator; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; @@ -96,6 +97,13 @@ public OutputStream apply(OutputStream stream) { @After public void tearDown() { Utils.deleteRecursively(tempDir); + // This call is a workaround for SPARK-7660, a snappy-java bug which is exposed by this test + // suite. Clearing the cached buffer allocator's pool of reusable buffers masks this bug, + // preventing a test failure in JavaAPISuite that would otherwise occur. The underlying bug + // needs to be fixed, but in the meantime this workaround avoids spurious Jenkins failures. + synchronized (CachedBufferAllocator.class) { + CachedBufferAllocator.queueTable.clear(); + } final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { fail("Test leaked " + leakedMemory + " bytes of managed memory"); From daf4ae72fe01b6d9631bfbd061b3846bdf668dfa Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Thu, 14 May 2015 23:50:50 -0700 Subject: [PATCH 051/109] [CORE] Remove unreachable Heartbeat message from Worker It doesn't look to me Heartbeat is sent to Worker from anyone. 
Author: Kan Zhang Closes #6163 from kanzhang/deadwood and squashes the following commits: 56be118 [Kan Zhang] [core] Remove unreachable Heartbeat message from Worker --- .../src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 8f3cc54051048..c8df024dda355 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -324,9 +324,6 @@ private[worker] class Worker( map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) - case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") - case RegisterWorkerFailed(message) => if (!registered) { logError("Worker registration failed: " + message) From cf842d42a70398671c4bc5ebfa70f6fdb8c57c7f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 14 May 2015 23:51:41 -0700 Subject: [PATCH 052/109] [SPARK-7650] [STREAMING] [WEBUI] Move streaming css and js files to the streaming project cc tdas Author: zsxwing Closes #6160 from zsxwing/SPARK-7650 and squashes the following commits: fe6ae15 [zsxwing] Fix the import order a4ffd99 [zsxwing] Merge branch 'master' into SPARK-7650 dc402b6 [zsxwing] Move streaming css and js files to the streaming project --- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../spark/streaming}/ui/static/streaming-page.css | 0 .../spark/streaming}/ui/static/streaming-page.js | 0 .../apache/spark/streaming/ui/StreamingPage.scala | 4 ++-- .../org/apache/spark/streaming/ui/StreamingTab.scala | 12 +++++++++++- 5 files changed, 14 insertions(+), 4 deletions(-) rename {core/src/main/resources/org/apache/spark => streaming/src/main/resources/org/apache/spark/streaming}/ui/static/streaming-page.css (100%) rename {core/src/main/resources/org/apache/spark => streaming/src/main/resources/org/apache/spark/streaming}/ui/static/streaming-page.js (100%) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 384f2ad26e281..1df9cd0fa18b4 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -94,7 +94,7 @@ private[spark] abstract class WebUI( } /** Detach a handler from this UI. 
*/ - protected def detachHandler(handler: ServletContextHandler) { + def detachHandler(handler: ServletContextHandler) { handlers -= handler serverInfo.foreach { info => info.rootHandler.removeHandler(handler) diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css similarity index 100% rename from core/src/main/resources/org/apache/spark/ui/static/streaming-page.css rename to streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js similarity index 100% rename from core/src/main/resources/org/apache/spark/ui/static/streaming-page.js rename to streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 070564aa10633..4ee7a486e370b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -166,8 +166,8 @@ private[ui] class StreamingPage(parent: StreamingTab) private def generateLoadResources(): Seq[Node] = { // scalastyle:off - - + + // scalastyle:on } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index f307b54bb9630..e0c0f57212f55 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.ui +import org.eclipse.jetty.servlet.ServletContextHandler + import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} import StreamingTab._ @@ -30,6 +32,8 @@ import StreamingTab._ private[spark] class StreamingTab(val ssc: StreamingContext) extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" + val parent = getSparkUI(ssc) val listener = ssc.progressListener @@ -38,12 +42,18 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) + var staticHandler: ServletContextHandler = null + def attach() { getSparkUI(ssc).attachTab(this) + staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") + getSparkUI(ssc).attachHandler(staticHandler) } def detach() { getSparkUI(ssc).detachTab(this) + getSparkUI(ssc).detachHandler(staticHandler) + staticHandler = null } } From 94761485b207fa1f12a8410a68920300d851bf61 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 15 May 2015 00:18:39 -0700 Subject: [PATCH 053/109] [SPARK-6258] [MLLIB] GaussianMixture Python API parity check Implement Python API for major disparities of GaussianMixture cluster algorithm between Scala & Python ```scala GaussianMixture setInitialModel GaussianMixtureModel k ``` Author: Yanbo Liang Closes #6087 from yanboliang/spark-6258 and squashes the following commits: b3af21c [Yanbo Liang] fix typo 2b645c1 [Yanbo Liang] fix doc 
638b4b7 [Yanbo Liang] address comments b5bcade [Yanbo Liang] GaussianMixture Python API parity check --- .../mllib/api/python/PythonMLLibAPI.scala | 24 +++++-- .../clustering/GaussianMixtureModel.scala | 9 ++- python/pyspark/mllib/clustering.py | 67 +++++++++++++++---- 3 files changed, 75 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f4c477596557f..2fa54df6fc2b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -345,28 +345,40 @@ private[python] class PythonMLLibAPI extends Serializable { * Returns a list containing weights, mean and covariance of each mixture component. */ def trainGaussianMixture( - data: JavaRDD[Vector], - k: Int, - convergenceTol: Double, + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, maxIterations: Int, - seed: java.lang.Long): JList[Object] = { + seed: java.lang.Long, + initialModelWeights: java.util.ArrayList[Double], + initialModelMu: java.util.ArrayList[Vector], + initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) + if (initialModelWeights != null && initialModelMu != null && initialModelSigma != null) { + val gaussians = initialModelMu.asScala.toSeq.zip(initialModelSigma.asScala.toSeq).map { + case (x, y) => new MultivariateGaussian(x.asInstanceOf[Vector], y.asInstanceOf[Matrix]) + } + val initialModel = new GaussianMixtureModel( + initialModelWeights.asScala.toArray, gaussians.toArray) + gmmAlg.setInitialModel(initialModel) + } + if (seed != null) gmmAlg.setSeed(seed) try { val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] + var mu = ArrayBuffer.empty[Vector] var sigma = ArrayBuffer.empty[Matrix] for (i <- 0 until model.k) { wt += model.weights(i) mu += model.gaussians(i).mu sigma += model.gaussians(i).sigma - } + } List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava } finally { data.rdd.unpersist(blocking = false) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index ec65a3da689de..c22862c130e77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -38,11 +38,10 @@ import org.apache.spark.sql.{SQLContext, Row} * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. 
* - * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is - * the weight for Gaussian i, and weight.sum == 1 - * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i - * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the - * covariance matrix for Gaussian i + * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is + * the weight for Gaussian i, and weights.sum == 1 + * @param gaussians Array of MultivariateGaussian where gaussians(i) represents + * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ @Experimental class GaussianMixtureModel( diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 04e67158514f5..a53333dae6a82 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -142,6 +142,7 @@ class GaussianMixtureModel(object): """A clustering model derived from the Gaussian Mixture Model method. + >>> from pyspark.mllib.linalg import Vectors, DenseMatrix >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -154,11 +155,12 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True - >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, - ... -5.2211, -5.0602, 4.7118, - ... 6.8989, 3.4592, 4.6322, - ... 5.7048, 4.6567, 5.5026, - ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + >>> data = array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]) + >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() @@ -166,12 +168,38 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True + >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) + >>> im = GaussianMixtureModel([0.5, 0.5], + ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), + ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) + >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ def __init__(self, weights, gaussians): - self.weights = weights - self.gaussians = gaussians - self.k = len(self.weights) + self._weights = weights + self._gaussians = gaussians + self._k = len(self._weights) + + @property + def weights(self): + """ + Weights for each Gaussian distribution in the mixture, where weights[i] is + the weight for Gaussian i, and weights.sum == 1. + """ + return self._weights + + @property + def gaussians(self): + """ + Array of MultivariateGaussian where gaussians[i] represents + the Multivariate Gaussian (Normal) Distribution for Gaussian i. + """ + return self._gaussians + + @property + def k(self): + """Number of gaussians in mixture.""" + return self._k def predict(self, x): """ @@ -193,9 +221,9 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. 
""" if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self.weights), means, sigmas) + _convert_to_vector(self._weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) @@ -208,13 +236,24 @@ class GaussianMixture(object): :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 :param maxIterations: Number of iterations. Default to 100 :param seed: Random Seed + :param initialModel: GaussianMixtureModel for initializing learning """ @classmethod - def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """Train a Gaussian Mixture clustering model.""" - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", - rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed) + initialModelWeights = None + initialModelMu = None + initialModelSigma = None + if initialModel is not None: + if initialModel.k != k: + raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" + % (initialModel.k, k)) + initialModelWeights = initialModel.weights + initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] + initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] + weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, + convergenceTol, maxIterations, seed, initialModelWeights, + initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) From fdf5bba35d201fe0de3901b4d47262c485c76569 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 15 May 2015 16:20:49 +0800 Subject: [PATCH 054/109] [SPARK-7591] [SQL] Partitioning support API tweaks Please see [SPARK-7591] [1] for the details. 
/cc rxin marmbrus yhuai [1]: https://issues.apache.org/jira/browse/SPARK-7591 Author: Cheng Lian Closes #6150 from liancheng/spark-7591 and squashes the following commits: af422e7 [Cheng Lian] Addresses @rxin's comments 37d1738 [Cheng Lian] Fixes HadoopFsRelation partition columns initialization 2fc680a [Cheng Lian] Fixes Scala style issue 189ad23 [Cheng Lian] Removes HadoopFsRelation constructor arguments 522c24e [Cheng Lian] Adds OutputWriterFactory 047d40d [Cheng Lian] Renames FSBased* to HadoopFs*, also renamed FSBasedParquetRelation back to ParquetRelation2 --- .../org/apache/spark/sql/SQLContext.scala | 14 +- ...{fsBasedParquet.scala => newParquet.scala} | 71 ++++----- .../sql/sources/DataSourceStrategy.scala | 10 +- .../spark/sql/sources/PartitioningUtils.scala | 4 + .../apache/spark/sql/sources/commands.scala | 23 ++- .../org/apache/spark/sql/sources/ddl.scala | 8 +- .../apache/spark/sql/sources/interfaces.scala | 140 +++++++++--------- .../org/apache/spark/sql/sources/rules.scala | 2 +- .../sql/parquet/ParquetFilterSuite.scala | 2 +- .../sql/parquet/ParquetSchemaSuite.scala | 12 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../spark/sql/hive/execution/commands.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 8 +- .../apache/spark/sql/hive/parquetSuites.scala | 20 +-- .../sql/sources/SimpleTextRelation.scala | 47 +++--- ...tes.scala => hadoopFsRelationSuites.scala} | 8 +- 17 files changed, 195 insertions(+), 194 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/parquet/{fsBasedParquet.scala => newParquet.scala} (92%) rename sql/hive/src/test/scala/org/apache/spark/sql/sources/{fsBasedRelationSuites.scala => hadoopFsRelationSuites.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b33a700208014..9fb355eb81939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json._ -import org.apache.spark.sql.parquet.FSBasedParquetRelation +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -610,7 +610,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } else if (conf.parquetUseDataSourceApi) { val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray baseRelationToDataFrame( - new FSBasedParquetRelation( + new ParquetRelation2( globbedPaths.map(_.toString), None, None, Map.empty[String, String])(this)) } else { DataFrame(this, parquet.ParquetRelation( @@ -989,7 +989,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def jdbc(url: String, table: String): DataFrame = { jdbc(url, table, JDBCRelation.columnPartition(null), new Properties()) } - + /** * :: Experimental :: * Construct a [[DataFrame]] representing the database table accessible via JDBC URL @@ -1002,7 +1002,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def jdbc(url: String, table: String, properties: Properties): DataFrame = { jdbc(url, table, JDBCRelation.columnPartition(null), properties) } - + /** * :: Experimental :: * Construct a 
[[DataFrame]] representing the database table accessible via JDBC URL @@ -1020,7 +1020,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @Experimental def jdbc( url: String, - table: String, + table: String, columnName: String, lowerBound: Long, upperBound: Long, @@ -1056,7 +1056,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val parts = JDBCRelation.columnPartition(partitioning) jdbc(url, table, parts, properties) } - + /** * :: Experimental :: * Construct a [[DataFrame]] representing the database table accessible via JDBC URL @@ -1093,7 +1093,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } jdbc(url, table, parts, properties) } - + private def jdbc( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala rename to sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c83a9c35dbddf..946062f6ea64e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/fsBasedParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -41,27 +41,23 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLConf, SQLContext} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} -private[sql] class DefaultSource extends FSBasedRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation = { + parameters: Map[String, String]): HadoopFsRelation = { val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty)) - new FSBasedParquetRelation(paths, schema, partitionSpec, parameters)(sqlContext) + new ParquetRelation2(paths, schema, partitionSpec, parameters)(sqlContext) } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
-private[sql] class ParquetOutputWriter extends OutputWriter { - private var recordWriter: RecordWriter[Void, Row] = _ - private var taskAttemptContext: TaskAttemptContext = _ - - override def init( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = { +private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) + extends OutputWriter { + + private val recordWriter: RecordWriter[Void, Row] = { val conf = context.getConfiguration val outputFormat = { // When appending new Parquet files to an existing Parquet file directory, to avoid @@ -77,7 +73,7 @@ private[sql] class ParquetOutputWriter extends OutputWriter { if (fs.exists(outputPath)) { // Pattern used to match task ID in part file names, e.g.: // - // part-r-00001.gz.part + // part-r-00001.gz.parquet // ^~~~~ val partFilePattern = """part-.-(\d{1,}).*""".r @@ -86,9 +82,8 @@ private[sql] class ParquetOutputWriter extends OutputWriter { case name if name.startsWith("_") => 0 case name if name.startsWith(".") => 0 case name => sys.error( - s"""Trying to write Parquet files to directory $outputPath, - |but found items with illegal name "$name" - """.stripMargin.replace('\n', ' ').trim) + s"Trying to write Parquet files to directory $outputPath, " + + s"but found items with illegal name '$name'.") }.reduceOption(_ max _).getOrElse(0) } else { 0 @@ -111,37 +106,39 @@ private[sql] class ParquetOutputWriter extends OutputWriter { } } - recordWriter = outputFormat.getRecordWriter(context) - taskAttemptContext = context + outputFormat.getRecordWriter(context) } override def write(row: Row): Unit = recordWriter.write(null, row) - override def close(): Unit = recordWriter.close(taskAttemptContext) + override def close(): Unit = recordWriter.close(context) } -private[sql] class FSBasedParquetRelation( - paths: Array[String], +private[sql] class ParquetRelation2( + override val paths: Array[String], private val maybeDataSchema: Option[StructType], private val maybePartitionSpec: Option[PartitionSpec], parameters: Map[String, String])( val sqlContext: SQLContext) - extends FSBasedRelation(paths, maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec) with Logging { // Should we merge schemas from all Parquet part-files? private val shouldMergeSchemas = - parameters.getOrElse(FSBasedParquetRelation.MERGE_SCHEMA, "true").toBoolean + parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean private val maybeMetastoreSchema = parameters - .get(FSBasedParquetRelation.METASTORE_SCHEMA) + .get(ParquetRelation2.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) - private val metadataCache = new MetadataCache - metadataCache.refresh() + private lazy val metadataCache: MetadataCache = { + val meta = new MetadataCache + meta.refresh() + meta + } override def equals(other: scala.Any): Boolean = other match { - case that: FSBasedParquetRelation => + case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { this.shouldMergeSchemas == that.shouldMergeSchemas } else { @@ -175,8 +172,6 @@ private[sql] class FSBasedParquetRelation( } } - override def outputWriterClass: Class[_ <: OutputWriter] = classOf[ParquetOutputWriter] - override def dataSchema: StructType = metadataCache.dataSchema override private[sql] def refresh(): Unit = { @@ -187,9 +182,12 @@ private[sql] class FSBasedParquetRelation( // Parquet data source always uses Catalyst internal representations. 
override val needConversion: Boolean = false - override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum + override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum + + override def userDefinedPartitionColumns: Option[StructType] = + maybePartitionSpec.map(_.partitionColumns) - override def prepareForWrite(job: Job): Unit = { + override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) val committerClass = @@ -224,6 +222,13 @@ private[sql] class FSBasedParquetRelation( .getOrElse( sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) + + new OutputWriterFactory { + override def newInstance( + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + } } override def buildScan( @@ -385,7 +390,7 @@ private[sql] class FSBasedParquetRelation( // case insensitivity issue and possible schema mismatch (probably caused by schema // evolution). maybeMetastoreSchema - .map(FSBasedParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0)) + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) .getOrElse(dataSchema0) } } @@ -439,12 +444,12 @@ private[sql] class FSBasedParquetRelation( "No schema defined, " + s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") - FSBasedParquetRelation.readSchema(filesToTouch.map(footers.apply), sqlContext) + ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) } } } -private[sql] object FSBasedParquetRelation extends Logging { +private[sql] object ParquetRelation2 extends Logging { // Whether we should merge schemas collected from all Parquet part-files. private[sql] val MERGE_SCHEMA = "mergeSchema" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index ee099ab9593c7..e6324b20b3065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -59,7 +59,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, _) => t.buildScan(a)) :: Nil // Scanning partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray @@ -87,7 +87,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { selectedPartitions) :: Nil // Scanning non-partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) => + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) => val inputPaths = t.paths.map(new Path(_)).flatMap { path => val fs = path.getFileSystem(t.sqlContext.sparkContext.hadoopConfiguration) val qualifiedPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -111,10 +111,10 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: FSBasedRelation), part, query, overwrite, false) if part.isEmpty => + l @ LogicalRelation(t: HadoopFsRelation), 
part, query, overwrite, false) if part.isEmpty => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append execution.ExecutedCommand( - InsertIntoFSBasedRelation(t, query, Array.empty[String], mode)) :: Nil + InsertIntoHadoopFsRelation(t, query, Array.empty[String], mode)) :: Nil case _ => Nil } @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { partitionColumns: StructType, partitions: Array[Partition]) = { val output = projections.map(_.toAttribute) - val relation = logicalRelation.relation.asInstanceOf[FSBasedRelation] + val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index d30f7f65e21c0..d1f0cdab55f66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -35,6 +35,10 @@ private[sql] case class Partition(values: Row, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) private[sql] object PartitioningUtils { + // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't + // depend on Hive. + private[sql] val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { require(columnNames.size == literals.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 7879328bbaaab..a09bb08de736a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -58,8 +58,8 @@ private[sql] case class InsertIntoDataSource( } } -private[sql] case class InsertIntoFSBasedRelation( - @transient relation: FSBasedRelation, +private[sql] case class InsertIntoHadoopFsRelation( + @transient relation: HadoopFsRelation, @transient query: LogicalPlan, partitionColumns: Array[String], mode: SaveMode) @@ -102,7 +102,7 @@ private[sql] case class InsertIntoFSBasedRelation( insert(new DefaultWriterContainer(relation, job), df) } else { val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__") + relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) } } @@ -234,7 +234,7 @@ private[sql] case class InsertIntoFSBasedRelation( } private[sql] abstract class BaseWriterContainer( - @transient val relation: FSBasedRelation, + @transient val relation: HadoopFsRelation, @transient job: Job) extends SparkHadoopMapReduceUtil with Logging @@ -261,7 +261,7 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass + protected var outputWriterFactory: OutputWriterFactory = _ private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ @@ -269,7 +269,7 @@ private[sql] abstract class BaseWriterContainer( setupIDs(0, 0, 0) setupConf() taskAttemptContext = 
newTaskAttemptContext(serializableConf.value, taskAttemptId) - relation.prepareForWrite(job) + outputWriterFactory = relation.prepareJobForWrite(job) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -346,16 +346,15 @@ private[sql] abstract class BaseWriterContainer( } private[sql] class DefaultWriterContainer( - @transient relation: FSBasedRelation, + @transient relation: HadoopFsRelation, @transient job: Job) extends BaseWriterContainer(relation, job) { @transient private var writer: OutputWriter = _ override protected def initWriters(): Unit = { - writer = outputWriterClass.newInstance() taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) - writer.init(getWorkPath, dataSchema, taskAttemptContext) + writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) } override def outputWriterForRow(row: Row): OutputWriter = writer @@ -372,7 +371,7 @@ private[sql] class DefaultWriterContainer( } private[sql] class DynamicPartitionWriterContainer( - @transient relation: FSBasedRelation, + @transient relation: HadoopFsRelation, @transient job: Job, partitionColumns: Array[String], defaultPartitionName: String) @@ -398,12 +397,10 @@ private[sql] class DynamicPartitionWriterContainer( outputWriters.getOrElseUpdate(partitionPath, { val path = new Path(getWorkPath, partitionPath) - val writer = outputWriterClass.newInstance() taskAttemptContext.getConfiguration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - writer.init(path.toString, dataSchema, taskAttemptContext) - writer + outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 595c5eb40e295..37a569db311ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -226,7 +226,7 @@ private[sql] object ResolvedDataSource { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => val maybePartitionsSchema = if (partitionColumns.isEmpty) { None } else { @@ -256,7 +256,7 @@ private[sql] object ResolvedDataSource { case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) @@ -296,7 +296,7 @@ private[sql] object ResolvedDataSource { val relation = clazz.newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. 
Output path must be a legal HDFS style file system path; @@ -315,7 +315,7 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) sqlContext.executePlan( - InsertIntoFSBasedRelation( + InsertIntoHadoopFsRelation( r, data.logicalPlan, partitionColumns.toArray, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6f315305c11d6..274ab4485217a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, _} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.types.{StructField, StructType} @@ -94,7 +94,7 @@ trait SchemaRelationProvider { * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a - * USING clause specified (to specify the implemented [[FSBasedRelationProvider]]), a user defined + * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined * schema, and an optional list of partition columns, this interface is used to pass in the * parameters specified by a user. * @@ -105,15 +105,15 @@ trait SchemaRelationProvider { * * A new instance of this class with be instantiated each time a DDL call is made. * - * The difference between a [[RelationProvider]] and a [[FSBasedRelationProvider]] is + * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when * using a SchemaRelationProvider. A relation provider can inherits both [[RelationProvider]], - * and [[FSBasedRelationProvider]] if it can support schema inference, user-specified + * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified * schemas, and accessing partitioned relations. * * @since 1.4.0 */ -trait FSBasedRelationProvider { +trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity @@ -124,7 +124,7 @@ trait FSBasedRelationProvider { paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation + parameters: Map[String, String]): HadoopFsRelation } /** @@ -280,33 +280,42 @@ trait CatalystScan { /** * ::Experimental:: - * [[OutputWriter]] is used together with [[FSBasedRelation]] for persisting rows to the - * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. - * An [[OutputWriter]] instance is created and initialized when a new output file is opened on - * executor side. This instance is used to persist rows to this single output file. + * A factory that produces [[OutputWriter]]s. 
A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. * * @since 1.4.0 */ @Experimental -abstract class OutputWriter { +abstract class OutputWriterFactory extends Serializable { /** - * Initializes this [[OutputWriter]] before any rows are persisted. + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. * * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that * this may not point to the final output file. For example, `FileOutputFormat` writes to * temporary directories and then merge written files back to the final destination. In * this case, `path` points to a temporary output file under the temporary directory. * @param dataSchema Schema of the rows to be written. Partition columns are not included in the - * schema if the corresponding relation is partitioned. + * schema if the relation being written is partitioned. * @param context The Hadoop MapReduce task context. * * @since 1.4.0 */ - def init( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = () + def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter +} +/** + * ::Experimental:: + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriter { /** * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. @@ -333,74 +342,71 @@ abstract class OutputWriter { * filter using selected predicates before producing an RDD containing all matching tuples as * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file * systems, it's able to discover partitioning information from the paths of input directories, and - * perform partition pruning before start reading the data. Subclasses of [[FSBasedRelation()]] must - * override one of the three `buildScan` methods to implement the read path. + * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] + * must override one of the three `buildScan` methods to implement the read path. * * For the write path, it provides the ability to write to both non-partitioned and partitioned * tables. Directory layout of the partitioned tables is compatible with Hive. * * @constructor This constructor is for internal uses only. The [[PartitionSpec]] argument is for * implementing metastore table conversion. - * @param paths Base paths of this relation. For partitioned relations, it should be the root - * directories of all partition directories. - * @param maybePartitionSpec An [[FSBasedRelation]] can be created with an optional + * + * @param maybePartitionSpec An [[HadoopFsRelation]] can be created with an optional * [[PartitionSpec]], so that partition discovery can be skipped. 
* * @since 1.4.0 */ @Experimental -abstract class FSBasedRelation private[sql]( - val paths: Array[String], - maybePartitionSpec: Option[PartitionSpec]) +abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) extends BaseRelation { + def this() = this(None) + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + private val codegenEnabled = sqlContext.conf.codegenEnabled + + private var _partitionSpec: PartitionSpec = _ + + final private[sql] def partitionSpec: PartitionSpec = { + if (_partitionSpec == null) { + _partitionSpec = maybePartitionSpec + .map(spec => spec.copy(partitionColumns = spec.partitionColumns.asNullable)) + .orElse(userDefinedPartitionColumns.map(PartitionSpec(_, Array.empty[Partition]))) + .getOrElse { + if (sqlContext.conf.partitionDiscoveryEnabled()) { + discoverPartitions() + } else { + PartitionSpec(StructType(Nil), Array.empty[Partition]) + } + } + } + _partitionSpec + } + /** - * Constructs an [[FSBasedRelation]]. - * - * @param paths Base paths of this relation. For partitioned relations, it should be either root - * directories of all partition directories. - * @param partitionColumns Partition columns of this relation. + * Base paths of this relation. For partitioned relations, it should be either root directories + * of all partition directories. * * @since 1.4.0 */ - def this(paths: Array[String], partitionColumns: StructType) = - this(paths, { - if (partitionColumns.isEmpty) None - else Some(PartitionSpec(partitionColumns, Array.empty[Partition])) - }) + def paths: Array[String] /** - * Constructs an [[FSBasedRelation]]. - * - * @param paths Base paths of this relation. For partitioned relations, it should be root - * directories of all partition directories. + * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically + * discovered. Note that they should always be nullable. * * @since 1.4.0 */ - def this(paths: Array[String]) = this(paths, None) - - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) - - private val codegenEnabled = sqlContext.conf.codegenEnabled - - private var _partitionSpec: PartitionSpec = maybePartitionSpec.map { spec => - spec.copy(partitionColumns = spec.partitionColumns.asNullable) - }.getOrElse { - if (sqlContext.conf.partitionDiscoveryEnabled()) { - discoverPartitions() - } else { - PartitionSpec(StructType(Nil), Array.empty[Partition]) - } - } - - private[sql] def partitionSpec: PartitionSpec = _partitionSpec + final def partitionColumns: StructType = + userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) /** - * Partition columns. Note that they are always nullable. + * Optional user defined partition columns. 
* * @since 1.4.0 */ - def partitionColumns: StructType = partitionSpec.partitionColumns + def userDefinedPartitionColumns: Option[StructType] = None private[sql] def refresh(): Unit = { if (sqlContext.conf.partitionDiscoveryEnabled()) { @@ -419,7 +425,7 @@ abstract class FSBasedRelation private[sql]( }.map(_.getPath) if (leafDirs.nonEmpty) { - PartitioningUtils.parsePartitions(leafDirs, "__HIVE_DEFAULT_PARTITION__") + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) } else { PartitionSpec(StructType(Array.empty[StructField]), Array.empty[Partition]) } @@ -458,7 +464,7 @@ abstract class FSBasedRelation private[sql]( * @since 1.4.0 */ def buildScan(inputPaths: Array[String]): RDD[Row] = { - throw new RuntimeException( + throw new UnsupportedOperationException( "At least one buildScan() method should be overridden to read the relation.") } @@ -520,8 +526,8 @@ abstract class FSBasedRelation private[sql]( } /** - * Client side preparation for data writing can be put here. For example, user defined output - * committer can be configured here. + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here. * * Note that the only side effect expected here is mutating `job` via its setters. Especially, * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states @@ -529,13 +535,5 @@ abstract class FSBasedRelation private[sql]( * * @since 1.4.0 */ - def prepareForWrite(job: Job): Unit = () - - /** - * This method is responsible for producing a new [[OutputWriter]] for each newly opened output - * file on the executor side. - * - * @since 1.4.0 - */ - def outputWriterClass: Class[_ <: OutputWriter] + def prepareJobForWrite(job: Job): OutputWriterFactory } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index aad1d248d0a28..1eacdde7413f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -102,7 +102,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) => // OK - case logical.InsertIntoTable(LogicalRelation(_: FSBasedRelation), _, _, _, _) => // OK + case logical.InsertIntoTable(LogicalRelation(_: HadoopFsRelation), _, _, _, _) => // OK case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. 
failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 3bbc5b05868af..5ad439584716f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -63,7 +63,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { }.flatten.reduceOption(_ && _) val forParquetDataSource = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: FSBasedParquetRelation)) => filters + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters }.flatten.reduceOption(_ && _) forParquetTableScan.orElse(forParquetDataSource) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index fc90e3edce7fe..c964b6d984557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -204,7 +204,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -219,7 +219,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -230,7 +230,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -241,7 +241,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Conflicting non-nullable field names intercept[Throwable] { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -255,7 +255,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -268,7 +268,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. 
assert(intercept[Throwable] { - FSBasedParquetRelation.mergeMetastoreParquetSchema( + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index b0e82c8d033b2..2aa80b47a97e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.FSBasedParquetRelation +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} @@ -226,8 +226,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. val parquetOptions = Map( - FSBasedParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - FSBasedParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -238,7 +238,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: FSBasedParquetRelation) => + case logical@LogicalRelation(parquetRelation: ParquetRelation2) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. 
val useCached = @@ -281,7 +281,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new FSBasedParquetRelation( + new ParquetRelation2( paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created @@ -294,7 +294,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new FSBasedParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) + new ParquetRelation2(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 8e405e080489f..6609763343752 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -194,7 +194,7 @@ case class CreateMetastoreDataSourceAsSelect( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { - case l @ LogicalRelation(_: InsertableRelation | _: FSBasedRelation) => + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table $tableName because the resolved relation does not " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index da5d203d9d343..1bf1c1be3e3d3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.FSBasedParquetRelation +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -579,11 +579,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: FSBasedParquetRelation) => // OK + case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[FSBasedParquetRelation].getCanonicalName}") + s"${classOf[ParquetRelation2].getCanonicalName}") } // Clenup and reset confs. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 5c7152e2140db..dfe73c62c42b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} -import org.apache.spark.sql.parquet.FSBasedParquetRelation +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -175,17 +175,17 @@ class SQLQuerySuite extends QueryTest { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { - case LogicalRelation(r: FSBasedParquetRelation) => + case LogicalRelation(r: ParquetRelation2) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${FSBasedParquetRelation.getClass.getCanonicalName}.") + s"${ParquetRelation2.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${FSBasedParquetRelation.getClass.getCanonicalName} is expected, but found " + + s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 41bcbe84b0ef2..b6be09e2f8837 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.{FSBasedParquetRelation, ParquetTableScan} -import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoFSBasedRelation, LogicalRelation} +import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} +import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} import org.apache.spark.util.Utils @@ -291,10 +291,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: FSBasedParquetRelation) => // OK + case LogicalRelation(_: ParquetRelation2) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[FSBasedParquetRelation].getCanonicalName}") + s"${classOf[ParquetRelation2].getCanonicalName}") } sql("DROP TABLE IF EXISTS test_parquet_ctas") @@ -315,9 +315,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoFSBasedRelation(_: FSBasedParquetRelation, _, _, 
_)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[FSBasedParquetRelation].getCanonicalName} and " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -345,9 +345,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoFSBasedRelation(r: FSBasedParquetRelation, _, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[FSBasedParquetRelation].getCanonicalName} and " + + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + s"However, found a ${o.toString} ") } @@ -378,7 +378,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: FSBasedParquetRelation) => r + case r @ LogicalRelation(_: ParquetRelation2) => r }.size } @@ -390,7 +390,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: FSBasedParquetRelation) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 8801aba2f64c3..29b21586f9c2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -24,7 +24,7 @@ import com.google.common.base.Objects import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -32,17 +32,16 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} /** - * A simple example [[FSBasedRelationProvider]]. + * A simple example [[HadoopFsRelationProvider]]. 
*/ -class SimpleTextSource extends FSBasedRelationProvider { +class SimpleTextSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation = { - val partitionsSchema = partitionColumns.getOrElse(StructType(Array.empty[StructField])) - new SimpleTextRelation(paths, schema, partitionsSchema, parameters)(sqlContext) + parameters: Map[String, String]): HadoopFsRelation = { + new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) } } @@ -59,38 +58,30 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW } } -class SimpleTextOutputWriter extends OutputWriter { - private var recordWriter: RecordWriter[NullWritable, Text] = _ - private var taskAttemptContext: TaskAttemptContext = _ - - override def init( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = { - recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) - taskAttemptContext = context - } +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { + private val recordWriter: RecordWriter[NullWritable, Text] = + new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map(_.toString).mkString(",") recordWriter.write(null, new Text(serialized)) } - override def close(): Unit = recordWriter.close(taskAttemptContext) + override def close(): Unit = recordWriter.close(context) } /** - * A simple example [[FSBasedRelation]], used for testing purposes. Data are stored as comma + * A simple example [[HadoopFsRelation]], used for testing purposes. Data are stored as comma * separated string lines. When scanning data, schema must be explicitly provided via data source * option `"dataSchema"`. 
*/ class SimpleTextRelation( - paths: Array[String], + override val paths: Array[String], val maybeDataSchema: Option[StructType], - partitionsSchema: StructType, + override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends FSBasedRelation(paths, partitionsSchema) { + extends HadoopFsRelation { import sqlContext.sparkContext @@ -110,9 +101,6 @@ class SimpleTextRelation( override def hashCode(): Int = Objects.hashCode(paths, maybeDataSchema, dataSchema) - override def outputWriterClass: Class[_ <: OutputWriter] = - classOf[SimpleTextOutputWriter] - override def buildScan(inputPaths: Array[String]): RDD[Row] = { val fields = dataSchema.map(_.dataType) @@ -122,4 +110,13 @@ class SimpleTextRelation( }: _*) } } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala similarity index 98% rename from sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 394833f22907d..cf6afd25ae5a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/fsBasedRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ // TODO Don't extend ParquetTest // This test suite extends ParquetTest for some convenient utility methods. These methods should be // moved to some more general places, maybe QueryTest. -class FSBasedRelationTest extends QueryTest with ParquetTest { +class HadoopFsRelationTest extends QueryTest with ParquetTest { override val sqlContext: SQLContext = TestHive import sqlContext._ @@ -487,7 +487,7 @@ class FSBasedRelationTest extends QueryTest with ParquetTest { } val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: FSBasedRelation) => + case LogicalRelation(relation: HadoopFsRelation) => relation.paths.toSet }.getOrElse { fail("Expect an FSBasedRelation, but none could be found") @@ -499,7 +499,7 @@ class FSBasedRelationTest extends QueryTest with ParquetTest { } } -class SimpleTextRelationSuite extends FSBasedRelationTest { +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName import sqlContext._ @@ -530,7 +530,7 @@ class SimpleTextRelationSuite extends FSBasedRelationTest { } } -class FSBasedParquetRelationSuite extends FSBasedRelationTest { +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName import sqlContext._ From c64ff8036cc6bc7c87743f4c751d7fe91c2e366a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 15 May 2015 11:37:34 +0100 Subject: [PATCH 055/109] [SPARK-7503] [YARN] Resources in .sparkStaging directory can't be cleaned up on error When we run applications on YARN with cluster mode, uploaded resources on .sparkStaging directory can't be cleaned up in case of failure of uploading local resources. 
You can see this issue by running following command. ``` bin/spark-submit --master yarn --deploy-mode cluster --class ``` Author: Kousuke Saruta Closes #6026 from sarutak/delete-uploaded-resources-on-error and squashes the following commits: caef9f4 [Kousuke Saruta] Fixed style 882f921 [Kousuke Saruta] Wrapped Client#submitApplication with try/catch blocks in order to delete resources on error 1786ca4 [Kousuke Saruta] Merge branch 'master' of https://github.com/apache/spark into delete-uploaded-resources-on-error f61071b [Kousuke Saruta] Fixed cleanup problem --- .../org/apache/spark/deploy/yarn/Client.scala | 72 ++++++++++++------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d21a7393478ce..7e023f2d92578 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction @@ -91,30 +91,52 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { - // Setup the credentials before doing anything else, so we have don't have issues at any point. - setupCredentials() - yarnClient.init(yarnConf) - yarnClient.start() - - logInfo("Requesting a new application from cluster with %d NodeManagers" - .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - - // Get a new application from our RM - val newApp = yarnClient.createApplication() - val newAppResponse = newApp.getNewApplicationResponse() - val appId = newAppResponse.getApplicationId() - - // Verify whether the cluster has enough resources for our AM - verifyClusterResources(newAppResponse) - - // Set up the appropriate contexts to launch our AM - val containerContext = createContainerLaunchContext(newAppResponse) - val appContext = createApplicationSubmissionContext(newApp, containerContext) - - // Finally, submit and monitor the application - logInfo(s"Submitting application ${appId.getId} to ResourceManager") - yarnClient.submitApplication(appContext) - appId + var appId: ApplicationId = null + try { + // Setup the credentials before doing anything else, + // so we have don't have issues at any point. 
+ setupCredentials() + yarnClient.init(yarnConf) + yarnClient.start() + + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + + // Get a new application from our RM + val newApp = yarnClient.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + appId = newAppResponse.getApplicationId() + + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) + + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) + + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) + appId + } catch { + case e: Throwable => + if (appId != null) { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + throw e + } } /** From f96b85ab44b82736363764ea39ee62884007f4a3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 May 2015 10:03:29 -0700 Subject: [PATCH 056/109] [SPARK-7668] [MLLIB] Preserve isTransposed property for Matrix after calling map function JIRA: https://issues.apache.org/jira/browse/SPARK-7668 Author: Liang-Chi Hsieh Closes #6188 from viirya/fix_matrix_map and squashes the following commits: 2a7cc97 [Liang-Chi Hsieh] Preserve isTransposed property for Matrix after calling map function. --- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 3fa5e068d16d4..a609674df6b8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -273,7 +273,8 @@ class DenseMatrix( override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) + private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), + isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { val len = values.length @@ -535,7 +536,7 @@ class SparseMatrix( } private[mllib] def map(f: Double => Double) = - new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { val len = values.length From 8f4aaba0e4e3350ab152a476d08ff60e9495c6d2 Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Fri, 15 May 2015 10:43:18 -0700 Subject: [PATCH 057/109] [SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD. 
Author: FlytxtRnD Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits: 4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD --- python/pyspark/mllib/clustering.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a53333dae6a82..b55583f82223f 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -212,6 +212,9 @@ def predict(self, x): if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) def predictSoft(self, x): """ @@ -225,6 +228,9 @@ def predictSoft(self, x): membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), _convert_to_vector(self._weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) class GaussianMixture(object): From b1b9d5802e3d185f42711ab043a21c9d1eb4763f Mon Sep 17 00:00:00 2001 From: Oleksii Kostyliev Date: Fri, 15 May 2015 11:19:56 -0700 Subject: [PATCH 058/109] [SPARK-7233] [CORE] Detect REPL mode once

Description

Detect REPL mode once per JVM lifespan. Previous behavior was to check the presence of interpreter mode every time a job was submitted. In the case of execution of multiple short-lived jobs, this was causing massive mutual blocks between submission threads. For more details, please refer to https://issues.apache.org/jira/browse/SPARK-7233.
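For readers skimming the diff further below, the heart of the change is simply memoizing the reflective REPL probe behind a `lazy val`, so it runs at most once per JVM rather than once per submitted job. A minimal sketch of that pattern, with an illustrative object name (the real field lands in `org.apache.spark.util.Utils`, as the diff shows):

```scala
// Sketch only: cache an expensive, reflection-based environment probe once per JVM.
// The enclosing object is hypothetical; the actual field is Utils.isInInterpreter.
object ReplProbe {
  lazy val isInInterpreter: Boolean =
    try {
      // Look up the REPL entry point without a compile-time dependency on the repl module.
      Class.forName("spark.repl.Main").getMethod("interp").invoke(null) != null
    } catch {
      // Mirrors the patch: returning true for a missing class looks inverted (see the Notes
      // below), but flipping it currently breaks streaming tests, so it is left as-is.
      case _: ClassNotFoundException => true
    }
}
```

Because a Scala `lazy val` initializes under the monitor of its enclosing object, concurrent submitters can contend on this check at most once per JVM, which is what removes the repeated blocking described above.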

Notes

* I inverted the return value in case of catching an exception from `true` to `false`. It seems more logical to assume that if the REPL class is not found, we aren't in the interpreter mode. * I'd personally would call `classForName` with just a Spark classloader (`org.apache.spark.util.Utils#getSparkClassLoader`) but `org.apache.spark.util.Utils#getContextOrSparkClassLoader` is said to be preferable. * I struggled to come up with a concise, readable and clear unit test. Suggestions are welcome if you feel necessary. Author: Oleksii Kostyliev Author: Oleksii Kostyliev Closes #5835 from preeze/SPARK-7233 and squashes the following commits: 69bb9e4 [Oleksii Kostyliev] SPARK-7527: fixed explanatory comment to meet style-checker requirements 26dcc24 [Oleksii Kostyliev] SPARK-7527: fixed explanatory comment to meet style-checker requirements c6f9685 [Oleksii Kostyliev] Merge remote-tracking branch 'remotes/upstream/master' into SPARK-7233 b78a983 [Oleksii Kostyliev] SPARK-7527: revert the fix and let it be addressed separately at a later stage b64d441 [Oleksii Kostyliev] SPARK-7233: inline inInterpreter parameter into instantiateClass 86e2606 [Oleksii Kostyliev] SPARK-7233, SPARK-7527: Handle interpreter mode properly. c7ee69c [Oleksii Kostyliev] Merge remote-tracking branch 'upstream/master' into SPARK-7233 d6c07fc [Oleksii Kostyliev] SPARK-7233: properly handle the inverted meaning of isInInterpreter c319039 [Oleksii Kostyliev] SPARK-7233: move inInterpreter to Utils and make it lazy --- .../org/apache/spark/util/ClosureCleaner.scala | 16 +++------------- .../main/scala/org/apache/spark/util/Utils.scala | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 6fe32e469c732..6f2966bd4fd31 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -239,15 +239,6 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" + fields accessed by starting closure: " + accessedFields.size) accessedFields.foreach { f => logDebug(" " + f) } - val inInterpreter = { - try { - val interpClass = Class.forName("spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => true - } - } - // List of outer (class, object) pairs, ordered from outermost to innermost // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse @@ -274,7 +265,7 @@ private[spark] object ClosureCleaner extends Logging { // required fields from the original object. We need the parent here because the Java // language specification requires the first constructor parameter of any closure to be // its enclosing object. 
- val clone = instantiateClass(cls, parent, inInterpreter) + val clone = instantiateClass(cls, parent) for (fieldName <- accessedFields(cls)) { val field = cls.getDeclaredField(fieldName) field.setAccessible(true) @@ -327,9 +318,8 @@ private[spark] object ClosureCleaner extends Logging { private def instantiateClass( cls: Class[_], - enclosingObject: AnyRef, - inInterpreter: Boolean): AnyRef = { - if (!inInterpreter) { + enclosingObject: AnyRef): AnyRef = { + if (!Utils.isInInterpreter) { // This is a bona fide closure class, whose constructor has no effects // other than to set its fields, so use its constructor val cons = cls.getConstructors()(0) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 48843b4ae57c6..6a7d1fae3320e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1795,6 +1795,20 @@ private[spark] object Utils extends Logging { } } + lazy val isInInterpreter: Boolean = { + try { + val interpClass = classForName("spark.repl.Main") + interpClass.getMethod("interp").invoke(null) != null + } catch { + // Returning true seems to be a mistake. + // Currently changing it to false causes tests failures in Streaming. + // For a more detailed discussion, please, refer to + // https://github.com/apache/spark/pull/5835#issuecomment-101042271 and subsequent comments. + // Addressing this changed is tracked as https://issues.apache.org/jira/browse/SPARK-7527 + case _: ClassNotFoundException => true + } + } + /** * Return a well-formed URI for the file described by a user input string. * From 270d4b5181b95e3f1f131b1d65dde00a7e5b9d6e Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Fri, 15 May 2015 11:27:24 -0700 Subject: [PATCH 059/109] [CORE] Protect additional test vars from early GC Fix more places in which some test variables could be collected early by aggressive JVM optimization. Added a couple of comments to note where existing references are sufficient in the same test pattern. 
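As background for the hunks that follow: the suite observes cleanup through weak references, so an object may be reclaimed as soon as the JIT considers its local variable dead, even while that variable is still lexically in scope. A self-contained sketch of the hazard and the keep-alive fix, assuming nothing Spark-specific (all names here are illustrative):

```scala
import java.lang.ref.WeakReference

object EarlyGcSketch {
  def main(args: Array[String]): Unit = {
    var payload = new Array[Byte](1 << 20)
    val weakRef = new WeakReference(payload)

    System.gc()
    // Without a later use of `payload`, an aggressive JVM would be free to treat it as
    // unreachable before this point, so `weakRef.get` could already be null here.
    assert(weakRef.get != null)
    assert(payload.length > 0) // the "defeat early collection" reference from the patch

    payload = null // now drop the strong reference on purpose
    System.gc()
    // With no strong references left, the weak reference is expected (though not
    // guaranteed by the JVM spec) to clear, which is what the cleaner tests assert on.
  }
}
```

In the suite itself the equivalent "later use" is an extra `rdd.count()` placed after the pre-GC assertion, as the added lines below show.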
Author: Tim Ellison Closes #6187 from tellison/DefeatEarlyGC and squashes the following commits: 27329d9 [Tim Ellison] [CORE] Protect additional test vars from early GC --- .../scala/org/apache/spark/ContextCleanerSuite.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index cb30e1f4e63a1..0922a2c3599cc 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -165,6 +165,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { } // Test that GC causes RDD cleanup after dereferencing the RDD + // Note rdd is used after previous GC to avoid early collection by the JVM val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) rdd = null // Make RDD out of scope runGC() @@ -181,9 +182,9 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) } + rdd.count() // Defeat early collection by the JVM // Test that GC causes shuffle cleanup after dereferencing the RDD - rdd.count() // Defeat any early collection of rdd variable by the JVM val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope runGC() @@ -201,6 +202,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { } // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + // Note broadcast is used after previous GC to avoid early collection by the JVM val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) broadcast = null // Make broadcast variable out of scope runGC() @@ -226,7 +228,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { // the checkpoint is not cleaned by default (without the configuration set) var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) - rdd = null // Make RDD out of scope + rdd = null // Make RDD out of scope, ok if collected earlier runGC() postGCTester.assertCleanup() assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) @@ -245,6 +247,9 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { // Confirm the checkpoint directory exists assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + // Reference rdd to defeat any early collection by the JVM + rdd.count() + // Test that GC causes checkpoint data cleanup after dereferencing the RDD postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) rdd = null // Make RDD out of scope @@ -352,6 +357,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor intercept[Exception] { preGCTester.assertCleanup()(timeout(1000 millis)) } + rdd.count() // Defeat early collection by the JVM // Test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) From 8ab1450d3995b0c3ef64c5991b88c258e17bcb12 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 15 May 2015 11:30:19 -0700 Subject: [PATCH 060/109] [SPARK-5412] [DEPLOY] Cannot bind Master to a specific hostname as per the documentation Pass args to start-master.sh through to start-daemon.sh, as other scripts do, so that things like --host have effect on start-master.sh as per docs Author: Sean Owen Closes #6185 from srowen/SPARK-5412 and squashes the following 
commits: b3ce9da [Sean Owen] Pass args to start-master.sh through to start-daemon.sh, as other scripts do, so that things like --host have effect on start-master.sh as per docs --- sbin/start-master.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 17fff58f4f768..a7f5d5702fd80 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -22,6 +22,8 @@ sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" +ORIGINAL_ARGS="$@" + START_TACHYON=false while (( "$#" )); do @@ -53,7 +55,9 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ + --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ + $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP From ad92af9dbbd0c4e1224cca26da166382ed4f15b9 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 15 May 2015 11:54:13 -0700 Subject: [PATCH 061/109] [SPARK-7664] [WEBUI] DAG visualization: Fix incorrect link paths of DAG. In JobPage, we can jump a StagePage when we click corresponding box of DAG viz but the link path is incorrect. When we click a box like as follows ... ![screenshot_from_2015-05-15 19 24 25](https://cloud.githubusercontent.com/assets/4736016/7651528/5f7ef824-fb3c-11e4-9518-8c9ade2dff7a.png) We jump to index page. ![screenshot_from_2015-05-15 19 24 45](https://cloud.githubusercontent.com/assets/4736016/7651534/6d666274-fb3c-11e4-971c-c3f2dc2b1da2.png) Author: Kousuke Saruta Closes #6184 from sarutak/fix-link-path-of-dag-viz and squashes the following commits: faba3ba [Kousuke Saruta] Fix a incorrect link --- .../resources/org/apache/spark/ui/static/spark-dag-viz.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 8138eb0d4f390..ee48fd29a6432 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -186,8 +186,9 @@ function renderDagVizForJob(svgContainer) { var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; // Link each graph to the corresponding stage page (TODO: handle stage attempts) - var stageLink = "/stages/stage/?id=" + - stageId.replace(VizConstants.stagePrefix, "") + "&attempt=0&expandDagViz=true"; + var stageLink = $("#stage-" + stageId.replace(VizConstants.stagePrefix, "") + "-0") + .find("a") + .attr("href") + "&expandDagViz=true"; var container = svgContainer .append("a") .attr("xlink:href", stageLink) From 8e3822a0794b8b18436bd63d6859d40139a77090 Mon Sep 17 00:00:00 2001 From: ehnalis Date: Fri, 15 May 2015 12:14:02 -0700 Subject: [PATCH 062/109] [SPARK-7504] [YARN] NullPointerException when initializing SparkContext in YARN-cluster mode Added a simple checking for SparkContext. Also added two rational checking against null at AM object. Author: ehnalis Closes #6083 from ehnalis/cluster and squashes the following commits: 926bd96 [ehnalis] Moved check to SparkContext. 7c89b6e [ehnalis] Remove false line. 
ea2a5fe [ehnalis] [SPARK-7504] [YARN] NullPointerException when initializing SparkContext in YARN-cluster mode 4924e01 [ehnalis] [SPARK-7504] [YARN] NullPointerException when initializing SparkContext in YARN-cluster mode 39e4fa3 [ehnalis] SPARK-7504 [YARN] NullPointerException when initializing SparkContext in YARN-cluster mode 9f287c5 [ehnalis] [SPARK-7504] [YARN] NullPointerException when initializing SparkContext in YARN-cluster mode --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b59f562d05ead..af276e7b8d40c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -371,6 +371,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli throw new SparkException("An application name must be set in your configuration") } + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster + // yarn-standalone is deprecated, but still supported + if ((master == "yarn-cluster" || master == "yarn-standalone") && + !_conf.contains("spark.yarn.app.id")) { + throw new SparkException("Detected yarn-cluster mode, but isn't running on a cluster. " + + "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit.") + } + if (_conf.getBoolean("spark.logConf", false)) { logInfo("Spark configuration:\n" + _conf.toDebugString) } From 9b6cf285d0b60848b01b6c7e3421e8ac850a88ab Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 15 May 2015 13:54:09 -0700 Subject: [PATCH 063/109] [SPARK-7296] Add timeline visualization for stages in the UI. This PR builds on #2342 by adding a timeline view for the Stage page, showing how tasks spend their time. With this timeline, we can understand following things of a Stage. * When/where each task ran * Total duration of each task * Proportion of the time each task spends Also, this timeline view can scrollable and zoomable. 
Author: Kousuke Saruta Closes #5843 from sarutak/stage-page-timeline and squashes the following commits: 4ba9604 [Kousuke Saruta] Fixed the order of legends 16bb552 [Kousuke Saruta] Removed border of legend area 2e5d605 [Kousuke Saruta] Modified warning message 16cb2e6 [Kousuke Saruta] Merge branch 'master' of https://github.com/apache/spark into stage-page-timeline 7ae328f [Kousuke Saruta] Modified code style d5f794a [Kousuke Saruta] Fixed performance issues more 64e6642 [Kousuke Saruta] Merge branch 'master' of https://github.com/apache/spark into stage-page-timeline e4a3354 [Kousuke Saruta] minor code style change 878e3b8 [Kousuke Saruta] Fixed a bug that tooltip remains b9d8f1b [Kousuke Saruta] Fixed performance issue ac8842b [Kousuke Saruta] Fixed layout 2319739 [Kousuke Saruta] Modified appearances more 81903ab [Kousuke Saruta] Modified appearances a79dcc3 [Kousuke Saruta] Modified appearance 55a390c [Kousuke Saruta] Ignored scalastyle for a line-comment 29eae3e [Kousuke Saruta] limited to longest 1000 tasks 2a9e376 [Kousuke Saruta] Minor cleanup 385b6d2 [Kousuke Saruta] Added link feature ba1ac3e [Kousuke Saruta] Fixed style 2ae8520 [Kousuke Saruta] Updated bootstrap-tooltip.js from 2.2.2 to 2.3.2 af430f1 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into stage-page-timeline e694b8e [Kousuke Saruta] Added timeline view to StagePage 8f6610c [Kousuke Saruta] Fixed conflict b587cf2 [Kousuke Saruta] initial commit 11fe67d [Kousuke Saruta] Fixed conflict 79ac03d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature a91abd3 [Kousuke Saruta] Merge branch 'master' of https://github.com/apache/spark into timeline-viewer-feature ef34a5b [Kousuke Saruta] Implement tooltip using bootstrap b09d0c5 [Kousuke Saruta] Move `stroke` and `fill` attribute of rect elements to css d3c63c8 [Kousuke Saruta] Fixed a little bit bugs a36291b [Kousuke Saruta] Merge branch 'master' of https://github.com/apache/spark into timeline-viewer-feature 28714b6 [Kousuke Saruta] Fixed highlight issue 0dc4278 [Kousuke Saruta] Addressed most of Patrics's feedbacks 8110acf [Kousuke Saruta] Added scroll limit to Job timeline 974a64a [Kousuke Saruta] Removed unused function ee7a7f0 [Kousuke Saruta] Refactored 6a91872 [Kousuke Saruta] Temporary commit 6693f34 [Kousuke Saruta] Added link to job/stage box in the timeline in order to move to corresponding row when we click 8f88222 [Kousuke Saruta] Added job/stage description aeed4b1 [Kousuke Saruta] Removed stage timeline fc1696c [Kousuke Saruta] Merge branch 'timeline-viewer-feature' of github.com:sarutak/spark into timeline-viewer-feature 999ccd4 [Kousuke Saruta] Improved scalability 0fc6a31 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature 19815ae [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature 68b7540 [Kousuke Saruta] Merge branch 'timeline-viewer-feature' of github.com:sarutak/spark into timeline-viewer-feature 52b5f0b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature dec85db [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature fcdab7d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature dab7cc1 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature 09cce97 [Kousuke Saruta] Cleanuped 16f82cf [Kousuke Saruta] Cleanuped 9fb522e [Kousuke 
Saruta] Cleanuped d05f2c2 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into timeline-viewer-feature e85e9aa [Kousuke Saruta] Cleanup: Added TimelineViewUtils.scala a76e569 [Kousuke Saruta] Removed unused setting in timeline-view.css 5ce1b21 [Kousuke Saruta] Added vis.min.js, vis.min.css and vis.map to .rat-exclude 082f709 [Kousuke Saruta] Added Timeline-View feature for Applications, Jobs and Stages --- .../apache/spark/ui/static/timeline-view.css | 66 +++++- .../apache/spark/ui/static/timeline-view.js | 71 +++++- .../org/apache/spark/ui/jobs/StagePage.scala | 220 +++++++++++++++++- .../org/apache/spark/ui/jobs/StagesTab.scala | 1 + 4 files changed, 348 insertions(+), 10 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index d1e6d462b836f..0f400461c5293 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -24,6 +24,65 @@ div#application-timeline, div#job-timeline { margin-top: 5px; } +#task-assignment-timeline div.legend-area { + width: 574px; +} + +#task-assignment-timeline .legend-area > svg { + width: 100%; + height: 55px; +} + +#task-assignment-timeline div.item.range { + padding: 0px; + height: 26px; + border-width: 0; +} + +.task-assignment-timeline-content { + width: 100%; +} + +.task-assignment-timeline-duration-bar { + width: 100%; + height: 26px; +} + +rect.scheduler-delay-proportion { + fill: #80B1D3; + stroke: #6B94B0; +} + +rect.deserialization-time-proportion { + fill: #FB8072; + stroke: #D26B5F; +} + +rect.shuffle-read-time-proportion { + fill: #FDB462; + stroke: #D39651; +} + +rect.executor-runtime-proportion { + fill: #B3DE69; + stroke: #95B957; +} + +rect.shuffle-write-time-proportion { + fill: #FFED6F; + stroke: #D5C65C; +} + +rect.serialization-time-proportion { + fill: #BC80BD; + stroke: #9D6B9E; +} + +rect.getting-result-time-proportion { + fill: #8DD3C7; + stroke: #75B0A6; +} + .vis.timeline { line-height: 14px; } @@ -178,6 +237,10 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { display: none; } +#task-assignment-timeline.collapsed { + display: none; +} + .control-panel { margin-bottom: 5px; } @@ -186,7 +249,8 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { margin: 0; } -span.expand-application-timeline, span.expand-job-timeline { +span.expand-application-timeline, span.expand-job-timeline, +span.expand-task-assignment-timeline { cursor: pointer; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 558beb8a5867f..e1150359bc901 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -133,6 +133,73 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }); } +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, zoomMax) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $("#task-assignment-timeline")[0] + var options = { + groupOrder: function(a, b) { + return a.value - b.value + }, + editable: false, + align: 'left', + selectable: false, + showCurrentTime: false, + min: minLaunchTime, + zoomable: false, + zoomMax: zoomMax + }; + + var taskTimeline = new vis.Timeline(container) + 
taskTimeline.setOptions(options); + taskTimeline.setGroups(groups); + taskTimeline.setItems(items); + + taskTimeline.on("rangechange", function(prop) { + if (currentDisplayedTooltip !== null) { + $(currentDisplayedTooltip).tooltip("hide"); + } + }); + + function getTaskIdxAndAttempt(selector) { + var taskIdxText = $(selector).attr("data-title"); + var taskIdxAndAttempt = taskIdxText.match("Task (\\d+) \\(attempt (\\d+)"); + var taskIdx = taskIdxAndAttempt[1]; + var taskAttempt = taskIdxAndAttempt[2]; + return taskIdx + "-" + taskAttempt; + } + + // If we zoom up and a box moves away when the corresponding tooltip is shown, + // the tooltip can be remain. + // So, we need to hide tooltips using another mechanism. + var currentDisplayedTooltip = null; + + $("#task-assignment-timeline").on({ + "mouseenter": function() { + var taskIdxAndAttempt = getTaskIdxAndAttempt(this); + $("#task-" + taskIdxAndAttempt).addClass("corresponding-item-hover"); + $(this).tooltip("show"); + currentDisplayedTooltip = this; + }, + "mouseleave" : function() { + var taskIdxAndAttempt = getTaskIdxAndAttempt(this); + $("#task-" + taskIdxAndAttempt).removeClass("corresponding-item-hover"); + $(this).tooltip("hide"); + currentDisplayedTooltip = null; + } + }, ".task-assignment-timeline-content"); + + setupZoomable('#task-assignment-timeline-zoom-lock', taskTimeline); + + $("span.expand-task-assignment-timeline").click(function() { + $("#task-assignment-timeline").toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-task-assignment-timeline-arrow').toggleClass('arrow-open'); + $(this).find('.expand-task-assignment-timeline-arrow').toggleClass('arrow-closed'); + }); +} + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( @@ -147,7 +214,7 @@ function setupExecutorEventAction() { } function setupZoomable(id, timeline) { - $(id + '>input[type="checkbox"]').click(function() { + $(id + ' > input[type="checkbox"]').click(function() { if (this.checked) { timeline.setOptions({zoomable: true}); } else { @@ -155,7 +222,7 @@ function setupZoomable(id, timeline) { } }); - $(id + ">span").click(function() { + $(id + " > span").click(function() { $(this).parent().find('input:checkbox').trigger('click'); }); } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8f7b1c2f09665..1a75ea62504a0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest +import scala.collection.mutable.HashSet import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -36,6 +37,35 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener + private val TIMELINE_LEGEND = { +
+ + { + val legendPairs = List(("scheduler-delay-proportion", "Scheduler Delay"), + ("deserialization-time-proportion", "Task Deserialization Time"), + ("shuffle-read-time-proportion", "Shuffle Read Time"), + ("executor-runtime-proportion", "Executor Computing Time"), + ("shuffle-write-time-proportion", "Shuffle Write Time"), + ("serialization-time-proportion", "Result Serialization TIme"), + ("getting-result-time-proportion", "Getting Result Time")) + + legendPairs.zipWithIndex.map { + case ((classAttr, name), index) => + + {name} + } + } + +
+ } + + // TODO: We should consider increasing the number of this parameter over time + // if we find that it's okay. + private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) + + def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { val parameterId = request.getParameter("id") @@ -196,7 +226,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Elem =
<tr><td>{acc.name}</td><td>{acc.value}</td></tr>
{info.index} {info.taskId} { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 55169956d8304..5989f0035b270 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -25,6 +25,7 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} /** Web UI showing progress status of all stages in the given SparkContext. */ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val sc = parent.sc + val conf = parent.conf val killEnabled = parent.killEnabled val progressListener = parent.jobProgressListener val operationGraphListener = parent.operationGraphListener From 50da9e89161faa0ecdc1feb3ffee6c822a742034 Mon Sep 17 00:00:00 2001 From: qhuang Date: Fri, 15 May 2015 14:06:16 -0700 Subject: [PATCH 064/109] [SPARK-7226] [SPARKR] Support math functions in R DataFrame Author: qhuang Closes #6170 from hqzizania/master and squashes the following commits: f20c39f [qhuang] add tests units and fixes 2a7d121 [qhuang] use a function name more familiar to R users 07aa72e [qhuang] Support math functions in R DataFrame --- R/pkg/NAMESPACE | 23 ++++++++++++++++++++ R/pkg/R/column.R | 36 +++++++++++++++++++++++++++++--- R/pkg/R/generics.R | 20 ++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++++++++ 4 files changed, 100 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba29614e7b179..64ffdcffc9caf 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -59,33 +59,56 @@ exportMethods("arrange", exportClasses("Column") exportMethods("abs", + "acos", "alias", "approxCountDistinct", "asc", + "asin", + "atan", + "atan2", "avg", "cast", + "cbrt", + "ceiling", "contains", + "cos", + "cosh", "countDistinct", "desc", "endsWith", + "exp", + "expm1", + "floor", "getField", "getItem", + "hypot", "isNotNull", "isNull", "last", "like", + "log", + "log10", + "log1p", "lower", "max", "mean", "min", "n", "n_distinct", + "rint", "rlike", + "sign", + "sin", + "sinh", "sqrt", "startsWith", "substr", "sum", "sumDistinct", + "tan", + "tanh", + "toDegrees", + "toRadians", "upper") exportClasses("GroupedData") diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9a68445ab451a..80e92d3105a36 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -55,12 +55,17 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or" #, "!" = "unary_$bang" + "&" = "and", "|" = "or", #, "!" 
= "unary_$bang" + "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct") + "first", "last", "lower", "upper", "sumDistinct", + "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", + "expm1", "floor", "log", "log10", "log1p", "rint", "sign", + "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") +binary_mathfunctions<- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -76,7 +81,11 @@ createOperator <- function(op) { if (class(e2) == "Column") { e2 <- e2@jc } - callJMethod(e1@jc, operators[[op]], e2) + if (op == "^") { + jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2) + } else { + callJMethod(e1@jc, operators[[op]], e2) + } } column(jc) }) @@ -106,11 +115,29 @@ createStaticFunction <- function(name) { setMethod(name, signature(x = "Column"), function(x) { + if (name == "ceiling") { + name <- "ceil" + } + if (name == "sign") { + name <- "signum" + } jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) column(jc) }) } +createBinaryMathfunctions <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -124,6 +151,9 @@ createMethods <- function() { for (x in functions) { createStaticFunction(x) } + for (name in binary_mathfunctions) { + createBinaryMathfunctions(name) + } } createMethods() diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6d2bfb1181e5a..a23d3b217b2fd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -552,6 +552,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) +#' @rdname column +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) @@ -575,6 +579,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -603,6 +611,10 @@ setGeneric("n", function(x) { standardGeneric("n") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @rdname column +#' @export +setGeneric("rint", function(x, ...) { standardGeneric("rint") }) + #' @rdname column #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) @@ -615,6 +627,14 @@ setGeneric("startsWith", function(x, ...) 
{ standardGeneric("startsWith") }) #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +#' @rdname column +#' @export +setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) + +#' @rdname column +#' @export +setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) + #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1109e8fdba3fd..3e5658eb5b24b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,7 @@ test_that("column operators", { c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 }) test_that("column functions", { @@ -538,6 +539,29 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) + c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) + c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) + c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) + c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) + c9 <- toDegrees(c) + toRadians(c) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- jsonFile(sqlCtx, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) }) test_that("string operators", { From 6e77105e11ff81bfd84561f4e1121111f686df21 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 15 May 2015 14:57:29 -0700 Subject: [PATCH 065/109] [SPARK-7677] [STREAMING] Add Kafka modules to the 2.11 build. This is somewhat related to [SPARK-6154](https://issues.apache.org/jira/browse/SPARK-6154), though it only touches Kafka, not the jline dependency for thriftserver. I tested this locally on 2.11 (./run-tests) and everything looked good (I had to disable mima, because `MimaBuild` harcodes 2.10 for the previous version -- that's another PR). Author: Iulian Dragos Closes #6149 from dragos/issue/spark-2.11-kafka and squashes the following commits: aa15d99 [Iulian Dragos] Add Kafka modules to the 2.11 build. 
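For context, the Kafka modules added to the Scala 2.11 build here back Spark Streaming's Kafka integration. The snippet below is only an illustrative Scala sketch of the direct-stream API those modules expose, not part of this patch; the broker address, topic name, and object name are placeholders, and it assumes the spark-streaming-kafka artifact is on the classpath.

import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object DirectKafkaSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DirectKafkaSketch")
    val ssc = new StreamingContext(conf, Seconds(2))

    // Placeholder broker list and topic set.
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
    val topics = Set("test")

    // Each record arrives as a (key, value) pair; count words in the values.
    val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, topics)
    messages.map(_._2).flatMap(_.split(" ")).map((_, 1L)).reduceByKey(_ + _).print()

    ssc.start()
    ssc.awaitTermination()
  }
}

The direct stream reads from Kafka without a receiver, which is why only a broker list and a topic set need to be supplied.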
--- pom.xml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 91d1d843c762a..86aa0a9fa134c 100644 --- a/pom.xml +++ b/pom.xml @@ -107,6 +107,8 @@ examples repl launcher + external/kafka + external/kafka-assembly @@ -1757,10 +1759,6 @@ ${scala.version} org.scala-lang - - external/kafka - external/kafka-assembly - From c8696337e2a5878f3171eb574c0a1365d45814c9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 May 2015 15:05:04 -0700 Subject: [PATCH 066/109] [SPARK-7556] [ML] [DOC] Add user guide for spark.ml Binarizer, including Scala, Java and Python examples JIRA: https://issues.apache.org/jira/browse/SPARK-7556 Author: Liang-Chi Hsieh Closes #6116 from viirya/binarizer_doc and squashes the following commits: 40cb677 [Liang-Chi Hsieh] Better print out. 5b7ef1d [Liang-Chi Hsieh] Make examples more clear. 1bf9c09 [Liang-Chi Hsieh] For comments. 6cf8cba [Liang-Chi Hsieh] Add user guide for Binarizer. --- docs/ml-features.md | 84 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 0cbebcb739b14..5df61dd36a070 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -183,6 +183,90 @@ for words_label in wordsDataFrame.select("words", "label").take(3): +## Binarizer + +Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. + +A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. + +
+<div data-lang="scala" markdown="1">
+{% highlight scala %} +import org.apache.spark.ml.feature.Binarizer +import org.apache.spark.sql.DataFrame + +val data = Array( + (0, 0.1), + (1, 0.8), + (2, 0.2) +) +val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + +val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + +val binarizedDataFrame = binarizer.transform(dataFrame) +val binarizedFeatures = binarizedDataFrame.select("binarized_feature") +binarizedFeatures.collect().foreach(println) +{% endhighlight %} +
+
+<div data-lang="java" markdown="1">
+{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); +Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); +DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); +DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); +for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); +} +{% endhighlight %} +
+
+<div data-lang="python" markdown="1">
+{% highlight python %} +from pyspark.ml.feature import Binarizer + +continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) +], ["label", "feature"]) +binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") +binarizedDataFrame = binarizer.transform(continuousDataFrame) +binarizedFeatures = binarizedDataFrame.select("binarized_feature") +for binarized_feature, in binarizedFeatures.collect(): + print binarized_feature +{% endhighlight %} +
+</div>
# Feature Selectors From e74545647684b3047248ca3cfee894ac5378dead Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Fri, 15 May 2015 17:45:14 -0700 Subject: [PATCH 067/109] [SPARK-7676] Bug fix and cleanup of stage timeline view cc pwendell sarutak This commit cleans up some unnecessary code, eliminates the feature where when you mouse-over a box in the timeline, the corresponding task is highlighted in the table (because that feature is only useful in the rare case when you have a very small number of tasks, in which case it's easy to figure out the mapping anyway), and fixes a bug where nothing shows up if you try to visualize a stage with only 1 task. Author: Kay Ousterhout Closes #6202 from kayousterhout/SPARK-7676 and squashes the following commits: dfd29d4 [Kay Ousterhout] [SPARK-7676] Bug fix and cleanup of stage timeline view --- .../apache/spark/ui/static/timeline-view.js | 48 +++++++------------ .../org/apache/spark/ui/jobs/StagePage.scala | 19 ++------ 2 files changed, 20 insertions(+), 47 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index e1150359bc901..604c29994145a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -133,7 +133,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }); } -function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, zoomMax) { +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); var container = $("#task-assignment-timeline")[0] @@ -146,8 +146,8 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, zo selectable: false, showCurrentTime: false, min: minLaunchTime, - zoomable: false, - zoomMax: zoomMax + max: maxFinishTime, + zoomable: false }; var taskTimeline = new vis.Timeline(container) @@ -155,48 +155,32 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, zo taskTimeline.setGroups(groups); taskTimeline.setItems(items); - taskTimeline.on("rangechange", function(prop) { - if (currentDisplayedTooltip !== null) { - $(currentDisplayedTooltip).tooltip("hide"); - } - }); - - function getTaskIdxAndAttempt(selector) { - var taskIdxText = $(selector).attr("data-title"); - var taskIdxAndAttempt = taskIdxText.match("Task (\\d+) \\(attempt (\\d+)"); - var taskIdx = taskIdxAndAttempt[1]; - var taskAttempt = taskIdxAndAttempt[2]; - return taskIdx + "-" + taskAttempt; - } - - // If we zoom up and a box moves away when the corresponding tooltip is shown, - // the tooltip can be remain. - // So, we need to hide tooltips using another mechanism. + // If a user zooms while a tooltip is displayed, the user may zoom such that the cursor is no + // longer over the task that the tooltip corresponds to. So, when a user zooms, we should hide + // any currently displayed tooltips. 
var currentDisplayedTooltip = null; - $("#task-assignment-timeline").on({ "mouseenter": function() { - var taskIdxAndAttempt = getTaskIdxAndAttempt(this); - $("#task-" + taskIdxAndAttempt).addClass("corresponding-item-hover"); - $(this).tooltip("show"); currentDisplayedTooltip = this; }, - "mouseleave" : function() { - var taskIdxAndAttempt = getTaskIdxAndAttempt(this); - $("#task-" + taskIdxAndAttempt).removeClass("corresponding-item-hover"); - $(this).tooltip("hide"); + "mouseleave": function() { currentDisplayedTooltip = null; } }, ".task-assignment-timeline-content"); + taskTimeline.on("rangechange", function(prop) { + if (currentDisplayedTooltip !== null) { + $(currentDisplayedTooltip).tooltip("hide"); + } + }); - setupZoomable('#task-assignment-timeline-zoom-lock', taskTimeline); + setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); $("span.expand-task-assignment-timeline").click(function() { - $("#task-assignment-timeline").toggleClass('collapsed'); + $("#task-assignment-timeline").toggleClass("collapsed"); // Switch the class of the arrow from open to closed. - $(this).find('.expand-task-assignment-timeline-arrow').toggleClass('arrow-open'); - $(this).find('.expand-task-assignment-timeline-arrow').toggleClass('arrow-closed'); + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); }); } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1a75ea62504a0..31e2e7fba9783 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -521,21 +521,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val host = taskInfo.host executorsSet += ((executorId, host)) - val classNameByStatus = { - if (taskInfo.successful) { - "succeeded" - } else if (taskInfo.failed) { - "failed" - } else if (taskInfo.running) { - "running" - } - } - val launchTime = taskInfo.launchTime val finishTime = if (!taskInfo.running) taskInfo.finishTime else currentTime val totalExecutionTime = finishTime - launchTime minLaunchTime = launchTime.min(minLaunchTime) - maxFinishTime = launchTime.max(maxFinishTime) + maxFinishTime = finishTime.max(maxFinishTime) def toProportion(time: Long) = (time.toDouble / totalExecutionTime * 100).toLong @@ -583,7 +573,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val timelineObject = s""" { - 'className': 'task task-assignment-timeline-object $classNameByStatus', + 'className': 'task task-assignment-timeline-object', 'group': '$executorId', 'content': '
Event Timeline @@ -671,7 +660,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
++ } @@ -748,7 +737,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") -
{info.index} {info.taskId} { From 2c04c8a1aed34cce420b3d30d9e885daa6e03d74 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 15 May 2015 18:06:01 -0700 Subject: [PATCH 068/109] [SPARK-7563] OutputCommitCoordinator.stop() should only run on the driver This fixes a bug where an executor that exits can cause the driver's OutputCommitCoordinator to stop. To fix this, we use an `isDriver` flag and check it in `stop()`. See https://issues.apache.org/jira/browse/SPARK-7563 for more details. Author: Josh Rosen Closes #6197 from JoshRosen/SPARK-7563 and squashes the following commits: 04b2cc5 [Josh Rosen] [SPARK-7563] OutputCommitCoordinator.stop() should only be executed on the driver --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/scheduler/OutputCommitCoordinator.scala | 10 ++++++---- .../spark/scheduler/OutputCommitCoordinatorSuite.scala | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a5d831c7e68ad..327114542880d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -379,7 +379,7 @@ object SparkEnv extends Logging { } val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - new OutputCommitCoordinator(conf) + new OutputCommitCoordinator(conf, isDriver) } val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 0b1d47cff3746..8321037cdc026 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -38,7 +38,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. 
*/ -private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { +private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging { // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None @@ -129,9 +129,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { } def stop(): Unit = synchronized { - coordinatorRef.foreach(_ send StopCoordinator) - coordinatorRef = None - authorizedCommittersByStage.clear() + if (isDriver) { + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None + authorizedCommittersByStage.clear() + } } // Marked private[scheduler] instead of private so this can be mocked in tests diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index cf97707946706..7078a7a12232a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -81,7 +81,7 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - outputCommitCoordinator = spy(new OutputCommitCoordinator(conf)) + outputCommitCoordinator = spy(new OutputCommitCoordinator(conf, isDriver = true)) // Use Mockito.spy() to maintain the default infrastructure everywhere else. // This mocking allows us to control the coordinator responses in test cases. SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) From cc12a86fb049f2be1f45baf461d202ec356ccf8f Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 15 May 2015 19:33:20 -0700 Subject: [PATCH 069/109] [SPARK-7575] [ML] [DOC] Example code for OneVsRest Java and Scala examples for OneVsRest. Fixes the base classifier to be Logistic Regression and accepts the configuration parameters of the base classifier. 
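For orientation before the full Java and Scala example runners below, here is a minimal Scala sketch of the pattern they implement. It is not part of the patch: it assumes a spark-shell session (so `sc` and `sqlContext` already exist), and the input path is a placeholder for any multiclass LibSVM dataset.

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.mllib.util.MLUtils

import sqlContext.implicits._

// Load a multiclass LibSVM dataset (placeholder path) and hold out a test split.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt")
val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 12345L).map(_.toDF().cache())

// Configure the binary base classifier; the parameter values mirror the example defaults.
val classifier = new LogisticRegression()
  .setMaxIter(100)
  .setTol(1E-6)
  .setFitIntercept(true)

// OneVsRest trains one binary model per class and picks the most confident prediction.
val ovr = new OneVsRest().setClassifier(classifier)
val ovrModel = ovr.fit(train)
val predictions = ovrModel.transform(test).select("prediction", "label")
predictions.show()

The OneVsRest estimator is parameterized by a base Classifier, which is why the runners expose the base classifier's configuration (maxIter, tol, regParam, and so on) as command-line options.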
Author: Ram Sriharsha Closes #6115 from harsha2010/SPARK-7575 and squashes the following commits: 87ad3c7 [Ram Sriharsha] extra line f5d9891 [Ram Sriharsha] Merge branch 'master' into SPARK-7575 7076084 [Ram Sriharsha] cleanup dfd660c [Ram Sriharsha] cleanup 8703e4f [Ram Sriharsha] update doc cb23995 [Ram Sriharsha] fix commandline options for JavaOneVsRestExample 69e91f8 [Ram Sriharsha] cleanup 7f4e127 [Ram Sriharsha] cleanup d4c40d0 [Ram Sriharsha] Code Review fixes 461eb38 [Ram Sriharsha] cleanup e0106d9 [Ram Sriharsha] Fix typo 935cf56 [Ram Sriharsha] Try to match Java and Scala Example Commandline options 5323ff9 [Ram Sriharsha] cleanup 196a59a [Ram Sriharsha] cleanup 6adfa0c [Ram Sriharsha] Style Fix 8cfc5d5 [Ram Sriharsha] [SPARK-7575] Example code for OneVsRest --- .../examples/ml/JavaOneVsRestExample.java | 236 ++++++++++++++++++ .../spark/examples/ml/OneVsRestExample.scala | 185 ++++++++++++++ 2 files changed, 421 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java new file mode 100644 index 0000000000000..75063dbf800d8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.commons.cli.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; +import org.apache.spark.ml.util.MetadataUtils; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + *
+ * bin/run-example ml.JavaOneVsRestExample [options]
+ * </pre>
+ */ +public class JavaOneVsRestExample { + + private static class Params { + String input; + String testInput = null; + Integer maxIter = 100; + double tol = 1E-6; + boolean fitIntercept = true; + Double regParam = null; + Double elasticNetParam = null; + double fracTest = 0.2; + } + + public static void main(String[] args) { + // parse the arguments + Params params = parse(args); + SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // configure the base classifier + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept); + + if (params.regParam != null) { + classifier.setRegParam(params.regParam); + } + if (params.elasticNetParam != null) { + classifier.setElasticNetParam(params.elasticNetParam); + } + + // instantiate the One Vs Rest Classifier + OneVsRest ovr = new OneVsRest().setClassifier(classifier); + + String input = params.input; + RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); + RDD train; + RDD test; + + // compute the train/ test split: if testInput is not provided use part of input + String testInput = params.testInput; + if (testInput != null) { + train = inputData; + // compute the number of features in the training set. + int numFeatures = inputData.first().features().size(); + test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + } else { + double f = params.fracTest; + RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + train = tmp[0]; + test = tmp[1]; + } + + // train the multiclass model + DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); + OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + + // score the model on test data + DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); + DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + .select("prediction", "label"); + + // obtain metrics + MulticlassMetrics metrics = new MulticlassMetrics(predictions); + StructField predictionColSchema = predictions.schema().apply("prediction"); + Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); + + // compute the false positive rate per label + StringBuilder results = new StringBuilder(); + results.append("label\tfpr\n"); + for (int label = 0; label < numClasses; label++) { + results.append(label); + results.append("\t"); + results.append(metrics.falsePositiveRate((double) label)); + results.append("\n"); + } + + Matrix confusionMatrix = metrics.confusionMatrix(); + // output the Confusion Matrix + System.out.println("Confusion Matrix"); + System.out.println(confusionMatrix); + System.out.println(); + System.out.println(results); + + jsc.stop(); + } + + private static Params parse(String[] args) { + Options options = generateCommandlineOptions(); + CommandLineParser parser = new PosixParser(); + Params params = new Params(); + + try { + CommandLine cmd = parser.parse(options, args); + String value; + if (cmd.hasOption("input")) { + params.input = cmd.getOptionValue("input"); + } + if (cmd.hasOption("maxIter")) { + value = cmd.getOptionValue("maxIter"); + params.maxIter = Integer.parseInt(value); + } + if (cmd.hasOption("tol")) { + value = cmd.getOptionValue("tol"); + params.tol = Double.parseDouble(value); + } + if (cmd.hasOption("fitIntercept")) { + value = cmd.getOptionValue("fitIntercept"); + params.fitIntercept = 
Boolean.parseBoolean(value); + } + if (cmd.hasOption("regParam")) { + value = cmd.getOptionValue("regParam"); + params.regParam = Double.parseDouble(value); + } + if (cmd.hasOption("elasticNetParam")) { + value = cmd.getOptionValue("elasticNetParam"); + params.elasticNetParam = Double.parseDouble(value); + } + if (cmd.hasOption("testInput")) { + value = cmd.getOptionValue("testInput"); + params.testInput = value; + } + if (cmd.hasOption("fracTest")) { + value = cmd.getOptionValue("fracTest"); + params.fracTest = Double.parseDouble(value); + } + + } catch (ParseException e) { + printHelpAndQuit(options); + } + return params; + } + + private static Options generateCommandlineOptions() { + Option input = OptionBuilder.withArgName("input") + .hasArg() + .isRequired() + .withDescription("input path to labeled examples. This path must be specified") + .create("input"); + Option testInput = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("input path to test examples") + .create("testInput"); + Option fracTest = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("fraction of data to hold out for testing." + + " If given option testInput, this option is ignored. default: 0.2") + .create("fracTest"); + Option maxIter = OptionBuilder.withArgName("maxIter") + .hasArg() + .withDescription("maximum number of iterations for Logistic Regression. default:100") + .create("maxIter"); + Option tol = OptionBuilder.withArgName("tol") + .hasArg() + .withDescription("the convergence tolerance of iterations " + + "for Logistic Regression. default: 1E-6") + .create("tol"); + Option fitIntercept = OptionBuilder.withArgName("fitIntercept") + .hasArg() + .withDescription("fit intercept for logistic regression. default true") + .create("fitIntercept"); + Option regParam = OptionBuilder.withArgName( "regParam" ) + .hasArg() + .withDescription("the regularization parameter for Logistic Regression.") + .create("regParam"); + Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) + .hasArg() + .withDescription("the ElasticNet mixing parameter for Logistic Regression.") + .create("elasticNetParam"); + + Options options = new Options() + .addOption(input) + .addOption(testInput) + .addOption(fracTest) + .addOption(maxIter) + .addOption(tol) + .addOption(fitIntercept) + .addOption(regParam) + .addOption(elasticNetParam); + + return options; + } + + private static void printHelpAndQuit(Options options) { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp("JavaOneVsRestExample", options); + System.exit(-1); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala new file mode 100644 index 0000000000000..b99d0a1246011 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} + +import scopt.OptionParser + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + * {{{ + * ./bin/run-example ml.OneVsRestExample [options] + * }}} + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object OneVsRestExample { + + case class Params private[ml] ( + input: String = null, + testInput: Option[String] = None, + maxIter: Int = 100, + tol: Double = 1E-6, + fitIntercept: Boolean = true, + regParam: Option[Double] = None, + elasticNetParam: Option[Double] = None, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("OneVsRest Example") { + head("OneVsRest Example: multiclass to binary reduction using OneVsRest") + opt[String]("input") + .text("input path to labeled examples. This path must be specified") + .required() + .action((x, c) => c.copy(input = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text("input path to test dataset. If given, option fracTest is ignored") + .action((x,c) => c.copy(testInput = Some(x))) + opt[Int]("maxIter") + .text(s"maximum number of iterations for Logistic Regression." + + s" default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations for Logistic Regression." + + s" default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Boolean]("fitIntercept") + .text(s"fit intercept for Logistic Regression." 
+ + s" default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("regParam") + .text(s"the regularization parameter for Logistic Regression.") + .action((x,c) => c.copy(regParam = Some(x))) + opt[Double]("elasticNetParam") + .text(s"the ElasticNet mixing parameter for Logistic Regression.") + .action((x,c) => c.copy(elasticNetParam = Some(x))) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") + val sc = new SparkContext(conf) + val inputData = MLUtils.loadLibSVMFile(sc, params.input) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // compute the train/test split: if testInput is not provided use part of input. + val data = params.testInput match { + case Some(t) => { + // compute the number of features in the training set. + val numFeatures = inputData.first().features.size + val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) + Array[RDD[LabeledPoint]](inputData, testData) + } + case None => { + val f = params.fracTest + inputData.randomSplit(Array(1 - f, f), seed = 12345) + } + } + val Array(train, test) = data.map(_.toDF().cache()) + + // instantiate the base classifier + val classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept) + + // Set regParam, elasticNetParam if specified in params + params.regParam.foreach(classifier.setRegParam) + params.elasticNetParam.foreach(classifier.setElasticNetParam) + + // instantiate the One Vs Rest Classifier. + + val ovr = new OneVsRest() + ovr.setClassifier(classifier) + + // train the multiclass model. + val (trainingDuration, ovrModel) = time(ovr.fit(train)) + + // score the model on test data. 
+ val (predictionDuration, predictions) = time(ovrModel.transform(test)) + + // evaluate the model + val predictionsAndLabels = predictions.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val metrics = new MulticlassMetrics(predictionsAndLabels) + + val confusionMatrix = metrics.confusionMatrix + + // compute the false positive rate per label + val predictionColSchema = predictions.schema("prediction") + val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get + val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + + println(s" Training Time ${trainingDuration} sec\n") + + println(s" Prediction Time ${predictionDuration} sec\n") + + println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") + + println("label\tfpr") + + println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + + sc.stop() + } + + private def time[R](block: => R): (Long, R) = { + val t0 = System.nanoTime() + val result = block // call-by-name + val t1 = System.nanoTime() + (NANO.toSeconds(t1 - t0), result) + } +} From adfd366814499c0540a15dd6017091ba8c0f05da Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 15 May 2015 20:05:26 -0700 Subject: [PATCH 070/109] [SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python Author: Davies Liu Closes #6206 from davies/sql_type and squashes the following commits: 33d6860 [Davies Liu] [SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python --- python/pyspark/sql/_types.py | 76 ++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 629c3a94513b8..9e7e9f04bc35d 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -73,56 +73,74 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle -class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" +class DataTypeSingleton(type): + """Metaclass for DataType""" _instances = {} def __call__(cls): if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + cls._instances[cls] = super(DataTypeSingleton, cls).__call__() return cls._instances[cls] -class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" +class NullType(DataType): + """Null type. - __metaclass__ = PrimitiveTypeSingleton + The data type representing None, used for the types that cannot be inferred. + """ + __metaclass__ = DataTypeSingleton -class NullType(PrimitiveType): - """Null type. - The data type representing None, used for the types that cannot be inferred. +class AtomicType(DataType): + """An internal type used to represent everything that is not + null, UDTs, arrays, structs, and maps.""" + + __metaclass__ = DataTypeSingleton + + +class NumericType(AtomicType): + """Numeric data types. """ -class StringType(PrimitiveType): +class IntegralType(NumericType): + """Integral data types. + """ + + +class FractionalType(NumericType): + """Fractional data types. + """ + + +class StringType(AtomicType): """String data type. """ -class BinaryType(PrimitiveType): +class BinaryType(AtomicType): """Binary (byte array) data type. """ -class BooleanType(PrimitiveType): +class BooleanType(AtomicType): """Boolean data type. """ -class DateType(PrimitiveType): +class DateType(AtomicType): """Date (datetime.date) data type. 
""" -class TimestampType(PrimitiveType): +class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ -class DecimalType(DataType): +class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. """ @@ -150,31 +168,31 @@ def __repr__(self): return "DecimalType()" -class DoubleType(PrimitiveType): +class DoubleType(FractionalType): """Double data type, representing double precision floats. """ -class FloatType(PrimitiveType): +class FloatType(FractionalType): """Float data type, representing single precision floats. """ -class ByteType(PrimitiveType): +class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' -class IntegerType(PrimitiveType): +class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' -class LongType(PrimitiveType): +class LongType(IntegralType): """Long data type, i.e. a signed 64-bit integer. If the values are beyond the range of [-9223372036854775808, 9223372036854775807], @@ -184,7 +202,7 @@ def simpleString(self): return 'bigint' -class ShortType(PrimitiveType): +class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): @@ -426,11 +444,9 @@ def __eq__(self, other): return type(self) == type(other) -_all_primitive_types = dict((v.typeName(), v) - for v in list(globals().values()) - if (type(v) is type or type(v) is PrimitiveTypeSingleton) - and v.__base__ == PrimitiveType) - +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -444,7 +460,7 @@ def _parse_datatype_json_string(json_string): ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype - >>> for cls in _all_primitive_types.values(): + >>> for cls in _all_atomic_types.values(): ... check_datatype(cls()) >>> # Simple ArrayType. 
@@ -494,8 +510,8 @@ def _parse_datatype_json_string(json_string): def _parse_datatype_json_value(json_value): if not isinstance(json_value, dict): - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if json_value in _all_atomic_types.keys(): + return _all_atomic_types[json_value]() elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): @@ -1125,7 +1141,7 @@ def Dict(d): return lambda datum: dataType.deserialize(datum) elif not isinstance(dataType, StructType): - # no wrapper for primitive types + # no wrapper for atomic types return lambda x: x class Row(tuple): From d7b69946cb21cd2781c9ad3e691e54b28efbbf3d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 15 May 2015 20:09:15 -0700 Subject: [PATCH 071/109] [SPARK-7543] [SQL] [PySpark] split dataframe.py into multiple files dataframe.py is splited into column.py, group.py and dataframe.py: ``` 360 column.py 1223 dataframe.py 183 group.py ``` Author: Davies Liu Closes #6201 from davies/split_df and squashes the following commits: fc8f5ab [Davies Liu] split dataframe.py into multiple files --- python/pyspark/sql/__init__.py | 5 +- python/pyspark/sql/column.py | 360 +++++++++++++++++++++++++ python/pyspark/sql/dataframe.py | 449 +------------------------------- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 183 +++++++++++++ python/run-tests | 2 + 6 files changed, 552 insertions(+), 449 deletions(-) create mode 100644 python/pyspark/sql/column.py create mode 100644 python/pyspark/sql/group.py diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 7192c89b3dc7f..19805e291e91b 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -55,8 +55,9 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions -from pyspark.sql.dataframe import DataFrameStatFunctions +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.group import GroupedData __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py new file mode 100644 index 0000000000000..fc7ad674daa5b --- /dev/null +++ b/python/pyspark/sql/column.py @@ -0,0 +1,360 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys + +if sys.version >= '3': + basestring = str + long = int + +from pyspark.context import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.types import * + +__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("apply") + + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. 
+ + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + >>> df.select(df.l[0], df.d["key"]).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + +----+ + |r[b]| + +----+ + | b| + +----+ + >>> df.select(df.r.a).show() + +----+ + |r[a]| + +----+ + | 1| + +----+ + """ + return self[name] + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, *alias): + """Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). 
+ + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + + @ignore_unicode_prefix + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + @ignore_unicode_prefix + def between(self, lowerBound, upperBound): + """ A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + """ + return (self >= lowerBound) & (self <= upperBound) + + @ignore_unicode_prefix + def when(self, condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + @ignore_unicode_prefix + def otherwise(self, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. 
+ """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(value) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + import pyspark.sql.column + globs = pyspark.sql.column.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.column, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2ed95ac8e2505..96d927b9ba35c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,17 +25,15 @@ else: from itertools import imap as map -from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import * from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.column import Column, _to_seq, _to_java_column - -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -757,6 +755,7 @@ def groupBy(self, *cols): [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ jdf = self._jdf.groupBy(self._jcols(*cols)) + from pyspark.sql.group import GroupedData return GroupedData(jdf, self.sql_ctx) def agg(self, *exprs): @@ -1141,169 +1140,6 @@ class SchemaRDD(DataFrame): """ -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -def df_varargs_api(f): - def _api(self, *args): - name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by :func:`DataFrame.groupBy`. - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - @ignore_unicode_prefix - def agg(self, *exprs): - """Compute aggregates and returns the result as a :class:`DataFrame`. - - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. - - If ``exprs`` is a single :class:`dict` mapping from string to string, then the key - is the column to perform aggregation on, and the value is the aggregate function. - - Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - - :param exprs: a dict mapping from column name (string) to aggregate functions (string), - or a list of :class:`Column`. 
- - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] - - >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """Counts the number of records for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @df_varargs_api - def mean(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def avg(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def max(self, *cols): - """Computes the max value for each numeric columns for each group. - - >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] - >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] - """ - - @df_varargs_api - def min(self, *cols): - """Computes the min value for each numeric column for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] - >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] - """ - - @df_varargs_api - def sum(self, *cols): - """Compute the sum for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] - >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] - """ - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.functions.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.functions.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _to_seq(sc, cols, converter=None): - """ - Convert a list of Column (or names) into a JVM Seq of Column. - - An optional `converter` could be used to convert items in `cols` - into JVM Column objects. - """ - if converter: - cols = [converter(c) for c in cols] - return sc._jvm.PythonUtils.toSeq(cols) - - def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. 
@@ -1311,282 +1147,6 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc) - _.__doc__ = doc - return _ - - -def _func_op(name, doc=''): - def _(self): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.functions, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -class Column(object): - - """ - A column in a DataFrame. - - :class:`Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc): - self._jc = jc - - # arithmetic operators - __neg__ = _func_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __truediv__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rtruediv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _func_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("apply") - - # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") - - def getItem(self, key): - """An expression that gets an item at position `ordinal` out of a list, - or gets an item by key out of a dict. - - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) - >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - >>> df.select(df.l[0], df.d["key"]).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - """ - return self[key] - - def getField(self, name): - """An expression that gets a field by name in a StructField. 
- - >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() - >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ - >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ - """ - return self[name] - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - return self.getField(item) - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - @ignore_unicode_prefix - def substr(self, startPos, length): - """ - Return a :class:`Column` which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.select(df.name.substr(1, 3).alias("col")).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc) - - __getslice__ = substr - - @ignore_unicode_prefix - def inSet(self, *cols): - """ A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - """ - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] - sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) - return Column(jc) - - # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, *alias): - """Returns this column aliased with a new name or names (in the case of expressions that - return more than one column, such as explode). 
- - >>> df.select(df.age.alias("age2")).collect() - [Row(age2=2), Row(age2=5)] - """ - - if len(alias) == 1: - return Column(getattr(self._jc, "as")(alias[0])) - else: - sc = SparkContext._active_spark_context - return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) - - @ignore_unicode_prefix - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - else: - raise TypeError("unexpected type: %s" % type(dataType)) - return Column(jc) - - @ignore_unicode_prefix - def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this - expression is between the given columns. - """ - return (self >= lowerBound) & (self <= upperBound) - - @ignore_unicode_prefix - def when(self, condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param condition: a boolean :class:`Column` expression. - :param value: a literal value, or a :class:`Column` expression. - - """ - sc = SparkContext._active_spark_context - if not isinstance(condition, Column): - raise TypeError("condition should be a Column") - v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) - return Column(jc) - - @ignore_unicode_prefix - def otherwise(self, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param value: a literal value, or a :class:`Column` expression. - """ - v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) - return Column(jc) - - def __repr__(self): - return 'Column<%s>' % self._jc.toString().encode('utf8') - - class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. 
""" @@ -1646,9 +1206,6 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6cd6974b0e5bb..8d0e766ecd3b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,7 +27,7 @@ from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq __all__ = [ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py new file mode 100644 index 0000000000000..9f7c743c051d3 --- /dev/null +++ b/python/pyspark/sql/group.py @@ -0,0 +1,183 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import * + +__all__ = ["GroupedData"] + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. 
+ + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age)=5, MAX(height)=85)] + """ + + @df_varargs_api + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age)=2, MIN(height)=80)] + """ + + @df_varargs_api + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. 
+ + >>> df.groupBy().sum('age').collect() + [Row(SUM(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age)=7, SUM(height)=165)] + """ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.group + globs = pyspark.sql.group.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.group, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index f2757a3967e81..ffde2fb24b369 100755 --- a/python/run-tests +++ b/python/run-tests @@ -72,7 +72,9 @@ function run_sql_tests() { echo "Run sql tests ..." run_test "pyspark/sql/_types.py" run_test "pyspark/sql/context.py" + run_test "pyspark/sql/column.py" run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/group.py" run_test "pyspark/sql/functions.py" run_test "pyspark/sql/tests.py" } From deb411335a09b91eb1f75421d77e1c3686719621 Mon Sep 17 00:00:00 2001 From: AiHe Date: Fri, 15 May 2015 20:42:35 -0700 Subject: [PATCH 072/109] [SPARK-7473] [MLLIB] Add reservoir sample in RandomForest reservoir feature sample by using existing api Author: AiHe Closes #5988 from AiHe/reservoir and squashes the following commits: e7a41ac [AiHe] remove non-robust testing case 28ffb9a [AiHe] set seed as rng.nextLong 37459e1 [AiHe] set fixed seed 1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest --- .../scala/org/apache/spark/mllib/tree/RandomForest.scala | 6 +++--- .../org/apache/spark/mllib/tree/RandomForestSuite.scala | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 055e60c7d9c95..b347c450c1aa8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? 
(use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index ee3bc98486862..4ed66953cb628 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ From 578bfeeff514228f6fd4b07a536815fbb3510f7e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 May 2015 22:00:31 -0700 Subject: [PATCH 073/109] [SPARK-7654][SQL] DataFrameReader and DataFrameWriter for input/output API This patch introduces DataFrameWriter and DataFrameReader. DataFrameReader interface, accessible through SQLContext.read, contains methods that create DataFrames. These methods used to reside in SQLContext. Example usage: ```scala sqlContext.read.json("...") sqlContext.read.parquet("...") ``` DataFrameWriter interface, accessible through DataFrame.write, implements a builder pattern to avoid the proliferation of options in writing DataFrame out. It currently implements: - mode - format (e.g. "parquet", "json") - options (generic options passed down into data sources) - partitionBy (partitioning columns) Example usage: ```scala df.write.mode("append").format("json").partitionBy("date").saveAsTable("myJsonTable") ``` TODO: - [ ] Documentation update - [ ] Move JDBC into reader / writer? - [ ] Deprecate the old interfaces - [ ] Move the generic load interface into reader. - [ ] Update example code and documentation Author: Reynold Xin Closes #6175 from rxin/reader-writer and squashes the following commits: b146c95 [Reynold Xin] Deprecation of old APIs. bd8abdf [Reynold Xin] Fixed merge conflict. 26abea2 [Reynold Xin] Added general load methods. 244fbec [Reynold Xin] Added equivalent to example. 4f15d92 [Reynold Xin] Added documentation for partitionBy. 7e91611 [Reynold Xin] [SPARK-7654][SQL] DataFrameReader and DataFrameWriter for input/output API. 
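
Reviewer note (not part of the diff): a minimal end-to-end sketch of chaining the new reader/writer builders described above, assuming the stock `examples/src/main/resources/people.json` input and illustrative output paths under `/tmp`:

```scala
// Read JSON through the new builder-style reader; "samplingRatio" is the
// existing JSON schema-inference option, passed via the generic option() hook.
val df = sqlContext.read
  .format("json")
  .option("samplingRatio", "1.0")
  .load("examples/src/main/resources/people.json")

// Write the same data back out as partitioned Parquet via the new writer.
df.write
  .format("parquet")
  .mode("overwrite")          // string form of SaveMode.Overwrite
  .partitionBy("age")         // illustrative partition column
  .save("/tmp/people_parquet")

// Shorthand equivalents added by this patch.
val df2 = sqlContext.read.parquet("/tmp/people_parquet")
df2.write.mode("append").json("/tmp/people_json")
```

Every builder method returns `this`, so format, mode, options and partitioning columns can be set in any order before the terminal `load()` / `save()` call.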
--- .../spark/examples/sql/JavaSparkSQL.java | 4 +- .../spark/examples/mllib/DatasetExample.scala | 2 +- .../spark/examples/sql/RDDRelation.scala | 2 +- .../org/apache/spark/sql/DataFrame.scala | 172 +++------- .../apache/spark/sql/DataFrameReader.scala | 218 ++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 198 +++++++++++ .../org/apache/spark/sql/SQLContext.scala | 158 +++------ .../spark/sql/parquet/ParquetTest.scala | 8 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 17 +- .../spark/sql/UserDefinedTypeSuite.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 50 +-- .../sql/parquet/ParquetFilterSuite.scala | 6 +- .../spark/sql/parquet/ParquetIOSuite.scala | 41 ++- .../ParquetPartitionDiscoverySuite.scala | 16 +- .../sources/CreateTableAsSelectSuite.scala | 2 +- .../spark/sql/sources/InsertSuite.scala | 10 +- .../spark/sql/sources/SaveLoadSuite.scala | 26 +- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../spark/sql/hive/HiveParquetSuite.scala | 8 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 18 +- .../apache/spark/sql/hive/parquetSuites.scala | 16 +- .../sql/sources/hadoopFsRelationSuites.scala | 321 ++++++++---------- 24 files changed, 772 insertions(+), 541 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8159ffbe2d269..173633ce059e3 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -99,7 +99,7 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index e943d6c889fab..c95cca7d656e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -106,7 +106,7 @@ object DatasetExample { df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) + val newDataset = sqlContext.read.parquet(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 6331d1c0060f8..acc89199d5849 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -61,7 +61,7 @@ object RDDRelation { df.saveAsParquetFile("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. - val parquetFile = sqlContext.parquetFile("pair.parquet") + val parquetFile = sqlContext.read.parquet("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2e20c3d3f4ed2..55ef357a99f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1289,6 +1289,16 @@ class DataFrame private[sql]( sqlContext.registerDataFrameAsTable(this, tableName) } + /** + * :: Experimental :: + * Interface for saving the content of the [[DataFrame]] out into external storage. + * + * @group output + * @since 1.4.0 + */ + @Experimental + def write: DataFrameWriter = new DataFrameWriter(this) + /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] @@ -1296,16 +1306,16 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ + @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { if (sqlContext.conf.parquetUseDataSourceApi) { - save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) + write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } else { sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd } } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame. * It will use the default data source configured by spark.sql.sources.default. * This will fail if the table already exists. 
@@ -1320,13 +1330,12 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { - saveAsTable(tableName, SaveMode.ErrorIfExists) + write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame, using the default data source * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. * @@ -1340,20 +1349,18 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { // If table already exists and the save mode is Append, // we will just call insertInto to append the contents of this DataFrame. insertInto(tableName, overwrite = false) } else { - val dataSourceName = sqlContext.conf.defaultDataSourceName - saveAsTable(tableName, dataSourceName, mode) + write.mode(mode).saveAsTable(tableName) } } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source and a set of options, * using [[SaveMode.ErrorIfExists]] as the save mode. @@ -1368,9 +1375,9 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { - saveAsTable(tableName, source, SaveMode.ErrorIfExists) + write.format(source).saveAsTable(tableName) } /** @@ -1388,13 +1395,12 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - saveAsTable(tableName, source, mode, Map.empty[String, String]) + write.format(source).mode(mode).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * @@ -1408,40 +1414,17 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. 
- * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap, partitionColumns) + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: * (Scala-specific) * Creates a table from the the contents of this DataFrame based on a given data source, * [[SaveMode]] specified by mode, and a set of options. @@ -1456,167 +1439,88 @@ class DataFrame private[sql]( * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: Map[String, String]): Unit = { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - Array.empty[String], - mode, - options, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - sqlContext.executePlan( - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitionColumns.toArray, - mode, - options, - logicalPlan)).toRdd - } - - /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path, * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { - save(path, SaveMode.ErrorIfExists) + write.save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - save(path, dataSourceName, mode) + write.mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { - save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + write.format(source).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. 
* @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { - save(source, mode, Map("path" -> path)) + write.format(source).mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - save(source, mode, options.toMap) + write.format(source).mode(mode).options(options).save() } /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - save(source, mode, options.toMap, partitionColumns) - } - - /** - * :: Experimental :: * (Scala-specific) * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output * @since 1.3.0 */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: Map[String, String]): Unit = { - ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this) - } - - /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this) + write.format(source).mode(mode).options(options).save() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala new file mode 100644 index 0000000000000..4d63faad6fb7c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -0,0 +1,218 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.json.{JsonRDD, JSONRelation} +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, + * key-value stores, etc). + * + * @since 1.4.0 + */ +@Experimental +class DataFrameReader private[sql](sqlContext: SQLContext) { + + /** + * Specifies the input data source format. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can + * skip the schema inference step, and thus speed up data loading. + * + * @since 1.4.0 + */ + def schema(schema: StructType): DataFrameReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameReader = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Specifies the input partitioning. If specified, the underlying data source does not need to + * discover the data partitioning scheme, and thus can speed up very large inputs. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataFrameReader = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + + /** + * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external + * key-value stores). + * + * @since 1.4.0 + */ + def load(): DataFrame = { + val resolved = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + provider = source, + options = extraOptions.toMap) + DataFrame(sqlContext, LogicalRelation(resolved.relation)) + } + + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. 
+ * + * @param path input path + * @since 1.4.0 + */ + def json(path: String): DataFrame = format("json").load(path) + + /** + * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) + + /** + * Loads an `RDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: RDD[String]): DataFrame = { + val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble + if (sqlContext.conf.useJacksonStreamingAPI) { + sqlContext.baseRelationToDataFrame( + new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) + } else { + val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord + val appliedSchema = userSpecifiedSchema.getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) + sqlContext.createDataFrame(rowRDD, appliedSchema, needsConversion = false) + } + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def parquet(paths: String*): DataFrame = { + if (paths.isEmpty) { + sqlContext.emptyDataFrame + } else { + val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray + sqlContext.baseRelationToDataFrame( + new ParquetRelation2( + globbedPaths.map(_.toString), None, None, Map.empty[String, String])(sqlContext)) + } + } + + /** + * Returns the specified table as a [[DataFrame]]. + * + * @since 1.4.0 + */ + def table(tableName: String): DataFrame = { + DataFrame(sqlContext, sqlContext.catalog.lookupRelation(Seq(tableName))) + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Option[Seq[String]] = None + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala new file mode 100644 index 0000000000000..b1fc18ac3cb54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -0,0 +1,198 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} + + +/** + * :: Experimental :: + * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, + * key-value stores, etc). + * + * @since 1.4.0 + */ +@Experimental +final class DataFrameWriter private[sql](df: DataFrame) { + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `SaveMode.Overwrite`: overwrite the existing data. + * - `SaveMode.Append`: append the data. + * - `SaveMode.Ignore`: ignore the operation (i.e. no-op). + * - `SaveMode.ErrorIfExists`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: SaveMode): DataFrameWriter = { + this.mode = saveMode + this + } + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `overwrite`: overwrite the existing data. + * - `append`: append the data. + * - `ignore`: ignore the operation (i.e. no-op). + * - `error`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: String): DataFrameWriter = { + saveMode.toLowerCase match { + case "overwrite" => SaveMode.Overwrite + case "append" => SaveMode.Append + case "ignore" => SaveMode.Ignore + case "error" | "default" => SaveMode.ErrorIfExists + case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + + "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + } + this + } + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameWriter = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataFrameWriter = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Saves the content of the [[DataFrame]] at the specified path. 
+ * + * @since 1.4.0 + */ + def save(path: String): Unit = { + this.extraOptions += ("path" -> path) + save() + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * @since 1.4.0 + */ + def save(): Unit = { + ResolvedDataSource( + df.sqlContext, + source, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df) + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * @since 1.4.0 + */ + def saveAsTable(tableName: String): Unit = { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd + } + + /** + * Saves the content of the [[DataFrame]] in JSON format at the specified path. + * This is equivalent to: + * {{{ + * format("json").save(path) + * }}} + * + * @since 1.4.0 + */ + def json(path: String): Unit = format("json").save(path) + + /** + * Saves the content of the [[DataFrame]] in Parquet format at the specified path. + * This is equivalent to: + * {{{ + * format("parquet").save(path) + * }}} + * + * @since 1.4.0 + */ + def parquet(path: String): Unit = format("parquet").save(path) + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var mode: SaveMode = SaveMode.ErrorIfExists + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Option[Seq[String]] = None + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9fb355eb81939..34a50e522c4ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -27,11 +27,9 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import com.google.common.reflect.TypeToken -import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ @@ -43,8 +41,6 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json._ -import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -596,6 +592,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rdd, beanClass) } + /** + * :: Experimental :: + * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. 
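Putting the reader and writer halves together, a rough round-trip sketch (paths, the table name, and the `sqlContext` handle are placeholders) could read:

{{{
val logs = sqlContext.read.json("/tmp/logs.json")

// Shorthand writers, equivalent to format("parquet").save(...) / format("json").save(...).
logs.write.parquet("/tmp/logs.parquet")
logs.write.mode("append").json("/tmp/logs-archive")

// saveAsTable goes through CreateTableUsingAsSelect; a persistent catalog
// (e.g. a HiveContext) is assumed for this line.
logs.write.format("parquet").mode("overwrite").saveAsTable("logs")
}}}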
+ * {{{ + * sqlContext.read.parquet("/path/to/file.parquet") + * sqlContext.read.schema(schema).json("/path/to/file.json") + * }}} + * + * @group genericdata + * @since 1.4.0 + */ + @Experimental + def read: DataFrameReader = new DataFrameReader(this) + /** * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty * [[DataFrame]] if no paths are passed in. @@ -603,15 +613,13 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group specificdata * @since 1.3.0 */ + @deprecated("Use read.parquet()", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { emptyDataFrame } else if (conf.parquetUseDataSourceApi) { - val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray - baseRelationToDataFrame( - new ParquetRelation2( - globbedPaths.map(_.toString), None, None, Map.empty[String, String])(this)) + read.parquet(paths : _*) } else { DataFrame(this, parquet.ParquetRelation( paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) @@ -625,28 +633,31 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group specificdata * @since 1.3.0 */ - def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String): DataFrame = { + read.json(path) + } /** - * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, * returning the result as a [[DataFrame]]. * * @group specificdata * @since 1.3.0 */ - @Experimental - def jsonFile(path: String, schema: StructType): DataFrame = - load("json", schema, Map("path" -> path)) + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, schema: StructType): DataFrame = { + read.schema(schema).json(path) + } /** - * :: Experimental :: * @group specificdata * @since 1.3.0 */ - @Experimental - def jsonFile(path: String, samplingRatio: Double): DataFrame = - load("json", Map("path" -> path, "samplingRatio" -> samplingRatio.toString)) + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(path) + } /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a @@ -656,8 +667,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group specificdata * @since 1.3.0 */ - def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) - + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a @@ -667,196 +678,131 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group specificdata * @since 1.3.0 */ - def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0) + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** - * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, * returning the result as a [[DataFrame]]. 
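The deprecation messages above translate mechanically into reader calls; for example (paths are placeholders):

{{{
// Deprecated as of 1.4:
val a = sqlContext.parquetFile("/tmp/data.parquet")
val b = sqlContext.jsonFile("/tmp/data.json", 0.5)

// Preferred replacements:
val a2 = sqlContext.read.parquet("/tmp/data.parquet")
val b2 = sqlContext.read.option("samplingRatio", "0.5").json("/tmp/data.json")
}}}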
* * @group specificdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - Option(schema).getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } + read.schema(schema).json(json) } /** - * :: Experimental :: * Loads an JavaRDD storing JSON objects (one object per record) and applies the given * schema, returning the result as a [[DataFrame]]. * * @group specificdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - jsonRDD(json.rdd, schema) + read.schema(schema).json(json) } /** - * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) inferring the * schema, returning the result as a [[DataFrame]]. * * @group specificdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } + read.option("samplingRatio", samplingRatio.toString).json(json) } /** - * :: Experimental :: * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the * schema, returning the result as a [[DataFrame]]. * * @group specificdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.json()", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - jsonRDD(json.rdd, samplingRatio); + read.option("samplingRatio", samplingRatio.toString).json(json) } /** - * :: Experimental :: * Returns the dataset stored at path as a DataFrame, * using the default data source configured by spark.sql.sources.default. * * @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.load(path)", "1.4.0") def load(path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - load(path, dataSourceName) + read.load(path) } /** - * :: Experimental :: * Returns the dataset stored at path as a DataFrame, using the given data source. * * @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.format(source).load(path)", "1.4.0") def load(path: String, source: String): DataFrame = { - load(source, Map("path" -> path)) + read.format(source).load(path) } /** - * :: Experimental :: * (Java-specific) Returns the dataset specified by the given data source and * a set of options as a DataFrame. 
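The same pattern applies to the RDD-based and source-based loaders deprecated here; a short sketch, with a made-up schema and path:

{{{
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val rdd    = sc.parallelize(Seq("""{"a": 1}""", """{"a": 2}"""))
val schema = StructType(StructField("a", LongType, nullable = true) :: Nil)

// Deprecated forms:
sqlContext.jsonRDD(rdd, schema)
sqlContext.load("/tmp/data.json", "json")

// 1.4 equivalents:
sqlContext.read.schema(schema).json(rdd)
sqlContext.read.format("json").load("/tmp/data.json")
}}}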
* * @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { - load(source, options.toMap) + read.options(options).format(source).load() } /** - * :: Experimental :: * (Scala-specific) Returns the dataset specified by the given data source and * a set of options as a DataFrame. * * @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.format(source).options(options).load()", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - load(source, schema, options.toMap) + read.options(options).format(source).load() } /** - * :: Experimental :: * (Java-specific) Returns the dataset specified by the given data source and * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. * * @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load( source: String, schema: StructType, - partitionColumns: Array[String], options: java.util.Map[String, String]): DataFrame = { - load(source, schema, partitionColumns, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) + read.format(source).schema(schema).options(options).load() } /** - * :: Experimental :: * (Scala-specific) Returns the dataset specified by the given data source and * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. 
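Likewise for the schema-plus-options overloads; the replacement suggested by the deprecation message, shown with an illustrative schema and path:

{{{
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val schema  = StructType(StructField("a", LongType, nullable = true) :: Nil)
val options = Map("path" -> "/tmp/data.json")

// Deprecated:
sqlContext.load("json", schema, options)

// Replacement:
sqlContext.read.format("json").schema(schema).options(options).load()
}}}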
* @group genericdata * @since 1.3.0 */ - @Experimental + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") def load( source: String, schema: StructType, - partitionColumns: Array[String], options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), partitionColumns, source, options) - DataFrame(this, LogicalRelation(resolved.relation)) + read.format(source).schema(schema).options(options).load() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 9d17516e0ef7d..7a73b6f1ac601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -90,7 +90,7 @@ private[sql] trait ParquetTest { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -102,7 +102,7 @@ private[sql] trait ParquetTest { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.parquetFile(path))) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -128,12 +128,12 @@ private[sql] trait ParquetTest { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + df.write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makePartitionDir( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index b76f7d421f643..6a0bcefe7aa88 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -75,9 +75,9 @@ public void setUp() throws IOException { public void saveAndLoad() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + df.save("json", SaveMode.ErrorIfExists, options); - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); + DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -86,12 +86,12 @@ public void saveAndLoad() { public void saveAndLoadWithSchema() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + df.save("json", SaveMode.ErrorIfExists, options); List fields = new ArrayList(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + DataFrame loadedDF = sqlContext.load("json", schema, options); checkAnswer(loadedDF, 
sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1d5f6b3aad6fd..054b23dba84c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -460,14 +460,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 479ad9fe621d0..c5c4f448a7224 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -105,7 +105,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("grouping on nested fields") { - jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -122,7 +122,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6201 IN type conversion") { - jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + read.json( + sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -1199,7 +1200,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3483 Special chars in column names") { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - jsonRDD(data).registerTempTable("records") + read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1240,11 +1241,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-4322 Grouping field with struct field as sub expression") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) dropTempTable("data") - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) dropTempTable("data") } @@ -1292,7 +1293,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: ORDER BY test for nested fields") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") 
checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1304,14 +1305,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: special cases") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 2672e20deadc5..dc2d43a197f40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -105,13 +105,13 @@ class UserDefinedTypeSuite extends QueryTest { test("UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.write.parquet(tempDir.getCanonicalPath) } test("Repartition UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) } // Tests to make sure that all operators correctly convert types on the way out. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index b06e3385980f7..6f747e5846f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -215,7 +215,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = jsonRDD(jsonNullStruct) + val jsonDF = read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +234,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = jsonRDD(primitiveFieldAndType) + val jsonDF = read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -262,7 +262,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -361,7 +361,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -377,7 +377,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -451,7 +451,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values 
(Ignored)") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -504,7 +504,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -528,7 +528,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonDF = jsonRDD(arrayElementTypeConflict) + val jsonDF = read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -556,7 +556,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = jsonRDD(missingFields) + val jsonDF = read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -576,7 +576,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = jsonFile(path, 0.49) + val jsonDF = read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -591,7 +591,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] + read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -603,7 +603,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = jsonFile(path) + val jsonDF = read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -672,7 +672,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = jsonFile(path, schema) + val jsonDF1 = read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -689,7 +689,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -710,7 +710,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -738,7 +738,7 @@ class JsonSuite extends QueryTest { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -764,7 +764,7 @@ 
class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -782,7 +782,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -805,7 +805,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = jsonRDD(jsonArray) + val jsonDF = read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -826,7 +826,7 @@ class JsonSuite extends QueryTest { val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = jsonRDD(corruptRecords) + val jsonDF = read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -880,7 +880,7 @@ class JsonSuite extends QueryTest { } test("SPARK-4068: nulls in arrays") { - val jsonDF = jsonRDD(nullsInArrays) + val jsonDF = read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -957,8 +957,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonDF.toJSON) + val jsonDF = read.json(primitiveFieldAndType) + val primTable = read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -970,8 +970,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonDF.toJSON) + val complexJsonDF = read.json(complexFieldAndType1) + val compTable = read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. 
checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 5ad439584716f..bdc2ebabc5e9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -328,12 +328,12 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.parquetFile(path).filter("part = 1"), + sqlContext.read.parquet(path).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -357,7 +357,7 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 008443df216aa..dd48bb350f26d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -114,24 +114,24 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } // Decimals with precision above 18 are not yet supported intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).collect() } } // Unlimited-length decimals are not yet supported intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).collect() } } } @@ -146,8 +146,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDateRDD() - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -283,7 +283,7 @@ class 
ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + checkAnswer(read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -311,8 +311,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("save - overwrite") { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) - checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) + checkAnswer(read.parquet(file), newData.map(Row.fromTuple)) } } @@ -320,8 +320,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) - checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) + checkAnswer(read.parquet(file), data.map(Row.fromTuple)) } } @@ -330,8 +330,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) val errorMessage = intercept[Throwable] { - newData.toDF().save( - "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + newData.toDF().write.format("parquet").mode(SaveMode.ErrorIfExists).save(file) }.getMessage assert(errorMessage.contains("already exists")) } @@ -341,8 +340,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) - checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) + checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -374,7 +373,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(parquetFile(path.toString).schema) { + assertResult(read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -392,7 +391,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(configuration) @@ -421,10 +420,10 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// intercept[Throwable] { - sqlContext.parquetFile("file:///nonexistent") + sqlContext.read.parquet("file:///nonexistent") } val 
errorMessage = intercept[Throwable] { - sqlContext.parquetFile("hdfs://nonexistent") + sqlContext.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 138e19766dc88..8079c460713da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -155,7 +155,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - parquetFile(base.getCanonicalPath).registerTempTable("t") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -202,7 +202,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - parquetFile(base.getCanonicalPath).registerTempTable("t") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -250,10 +250,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map("path" -> base.getCanonicalPath)) - + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -293,10 +290,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map("path" -> base.getCanonicalPath)) - + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -328,7 +322,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t") + read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 4e54b2eb8df7a..d2d1011b8e917 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -33,7 +33,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll(): Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index d1d427e1790bd..6f375ef36237d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -33,7 +33,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll: Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + read.json(rdd).registerTempTable("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -109,7 +109,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) - jsonRDD(rdd1).registerTempTable("jt1") + read.json(rdd1).registerTempTable("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -121,7 +121,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) - jsonRDD(rdd2).registerTempTable("jt2") + read.json(rdd2).registerTempTable("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -154,13 +154,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("save directly to the path of a JSON table") { - table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite) + table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) - table("jt").save(path.toString, "json", SaveMode.Overwrite) + table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 6567d1acd7644..7a28e9af3673c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -42,7 +42,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = jsonRDD(rdd) + df = read.json(rdd) df.registerTempTable("jsonTable") } @@ -57,41 +57,41 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { def checkLoad(): Unit = { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(load(path.toString), df.collect()) + checkAnswer(read.load(path.toString), df.collect()) // Test if we can pick up the data source name passed in load. 
conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + checkAnswer(read.format("json").load(path.toString), df.collect()) + checkAnswer(read.format("json").load(path.toString), df.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + read.format("json").schema(schema).load(path.toString), sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - df.save(path.toString) + df.write.save(path.toString) checkLoad() } test("save with path and datasource, and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } test("save and save again") { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) var message = intercept[RuntimeException] { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) }.getMessage assert( @@ -100,14 +100,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) checkLoad() - df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + df.write.mode(SaveMode.Overwrite).json(path.toString) checkLoad() message = intercept[RuntimeException] { - df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + df.write.mode(SaveMode.Append).json(path.toString) }.getMessage assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d46a127d47d31..c6b65106452bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -140,7 +140,7 @@ private[hive] trait HiveStrategies { PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil } else { hiveContext - .parquetFile(partitionLocations: _*) + .read.parquet(partitionLocations: _*) .addPartitioningAttributes(relation.partitionKeys) .lowerCase .where(unresolvedOtherPredicates) @@ -152,7 +152,7 @@ private[hive] trait HiveStrategies { } else { hiveContext - .parquetFile(relation.hiveQlTable.getDataLocation.toString) + .read.parquet(relation.hiveQlTable.getDataLocation.toString) .lowerCase .where(unresolvedOtherPredicates) .select(unresolvedProjection: _*) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 7ff5719adb3ab..5a5ea10e3c82e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -55,8 +55,8 @@ class HiveParquetSuite extends QueryTest with 
ParquetTest { test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -68,8 +68,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 1bf1c1be3e3d3..58b0b80c31e2e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -60,7 +60,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + read.json(filePath).collect().toSeq) } test ("persistent JSON table with a user specified schema") { @@ -77,7 +77,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - jsonFile(filePath).registerTempTable("expectedJsonTable") + read.json(filePath).registerTempTable("expectedJsonTable") checkAnswer( sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), @@ -104,7 +104,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { assert(expectedSchema === table("jsonTable").schema) - jsonFile(filePath).registerTempTable("expectedJsonTable") + read.json(filePath).registerTempTable("expectedJsonTable") checkAnswer( sql("SELECT b, ``.`=` FROM jsonTable"), @@ -123,7 +123,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + read.json(filePath).collect().toSeq) } test("drop table") { @@ -138,7 +138,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + read.json(filePath).collect().toSeq) sql("DROP TABLE jsonTable") @@ -241,7 +241,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |) """.stripMargin) - jsonFile(filePath).registerTempTable("expectedJsonTable") + read.json(filePath).registerTempTable("expectedJsonTable") checkAnswer( sql("SELECT * FROM jsonTable"), @@ -474,7 +474,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // Drop table will also delete the data. sql("DROP TABLE savedJsonTable") intercept[InvalidInputException] { - jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) } // Create an external table by specifying the path. 
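The test updates in these Hive-side hunks follow the same substitution as the core suites; schematically, with `df`, `path`, and `dir` standing in for whatever a given test uses:

{{{
import org.apache.spark.sql.SaveMode

// Older idioms being phased out:
df.saveAsParquetFile(dir.getCanonicalPath)
df.save(path, "org.apache.spark.sql.json", SaveMode.Overwrite)
sqlContext.jsonFile(path)

// Builder-based replacements used throughout the updated tests:
df.write.parquet(dir.getCanonicalPath)
df.write.mode(SaveMode.Overwrite).json(path)
sqlContext.read.json(path)
}}}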
@@ -491,7 +491,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // Data should not be deleted after we drop the table. sql("DROP TABLE savedJsonTable") checkAnswer( - jsonFile(tempPath.toString), + read.json(tempPath.toString), df.collect()) conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) @@ -526,7 +526,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // Data should not be deleted. sql("DROP TABLE createdJsonTable") checkAnswer( - jsonFile(tempPath.toString), + read.json(tempPath.toString), df.collect()) // Try to specify the schema. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index b6be09e2f8837..a0075f1e44ca8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -688,11 +688,11 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[Throwable](df2.saveAsParquetFile(filePath)) + intercept[Throwable](df2.write.parquet(filePath)) val df3 = df2.toDF("str", "max_int") - df3.saveAsParquetFile(filePath2) - val df4 = parquetFile(filePath2) + df3.write.parquet(filePath2) + val df4 = read.parquet(filePath2) checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } @@ -731,14 +731,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) .toDF() - .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) partitionedTableDirWithKey = Utils.createTempDir() @@ -747,7 +747,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() @@ -757,7 +757,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithKeyAndComplexTypes( p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithComplexTypes = Utils.createTempDir() @@ -766,7 +766,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index cf6afd25ae5a0..f44b3c521e647 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -92,44 +92,27 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + read.format(dataSourceName) + .option("path", file.getCanonicalPath) + .option("dataSchema", dataSchema.json) + .load(), testDF.collect()) } } test("save()/load() - non-partitioned table - Append") { withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Append) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)).orderBy("a"), + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -147,10 +130,7 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - non-partitioned table - Ignore") { withTempDir { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) + testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -160,89 +140,81 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - partitioned table - simple queries") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json))) + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath)) } } test("save()/load() - partitioned table - Overwrite") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + 
partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.collect()) } } test("save()/load() - partitioned table - Append") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } test("save()/load() - partitioned table - Append - new partition values") { withTempPath { file => - partitionedTestDF1.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.collect()) } } @@ -250,11 +222,11 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - partitioned table - ErrorIfExists") { withTempDir { file => intercept[RuntimeException] { - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) } } } @@ -343,19 +315,19 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - partitioned table - Overwrite") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + 
.option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { checkAnswer(table("t"), partitionedTestDF.collect()) @@ -363,19 +335,19 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - partitioned table - Append") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -383,19 +355,19 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - partitioned table - Append - new partition values") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { checkAnswer(table("t"), partitionedTestDF.collect()) @@ -403,31 +375,31 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") // Using only a subset of all partition columns intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1")) + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1") + .saveAsTable("t") } // Using different order of partition columns intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p2", "p1")) + partitionedTestDF2.write + .format(dataSourceName) + 
.mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p2", "p1") + .saveAsTable("t") } } @@ -436,12 +408,12 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { withTempTable("t") { intercept[AnalysisException] { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") } } } @@ -450,12 +422,12 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { Seq.empty[(Int, String)].toDF().registerTempTable("t") withTempTable("t") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Ignore) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") assert(table("t").collect().isEmpty) } @@ -463,17 +435,16 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("Hadoop style globbing") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - val df = load( - source = dataSourceName, - options = Map( - "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", - "dataSchema" -> dataSchema.json)) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + val df = read + .format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(s"${file.getCanonicalPath}/p1=*/p2=???") val expectedPaths = Set( s"${file.getCanonicalFile}/p1=1/p2=foo", From d41ae4344c07064de03a120804830886e1614d92 Mon Sep 17 00:00:00 2001 From: FavioVazquez Date: Sat, 16 May 2015 08:07:03 +0100 Subject: [PATCH 074/109] [SPARK-7671] Fix wrong URLs in MLlib Data Types Documentation There is a mistake in the URL of Matrices in the MLlib Data Types documentation (Local matrix scala section), the URL points to https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.linalg.Matrices which is a mistake, since Matrices is an object that implements factory methods for Matrix that does not have a companion class. The correct link should point to https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$ There is another mistake, in the Local Vector section in Scala, Java and Python In the Scala section the URL of Vectors points to the trait Vector (https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and not to the factory methods implemented in Vectors. 
The correct link should be: https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$ In the Java section the URL of Vectors points to the Interface Vector (https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/linalg/Vector.html) and not to the Class Vectors The correct link should be: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/linalg/Vectors.html In the Python section the URL of Vectors points to the class Vector (https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) and not the Class Vectors The correct link should be: https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors Author: FavioVazquez Closes #6196 from FavioVazquez/fix-typo-matrices-mllib-datatypes and squashes the following commits: 3e9efd5 [FavioVazquez] - Fixed wrong URLs in the MLlib Data Types Documentation 9af7074 [FavioVazquez] Merge remote-tracking branch 'upstream/master' edab1ef [FavioVazquez] Merge remote-tracking branch 'upstream/master' b2e2f8c [FavioVazquez] Merge remote-tracking branch 'upstream/master' --- docs/mllib-data-types.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 4f2a2f71048f7..acec0426dc69b 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -31,7 +31,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseVector) and [`SparseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseVector). We recommend using the factory methods implemented in -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) to create local vectors. +[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -57,7 +57,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVector.html) and [`SparseVector`](api/java/org/apache/spark/mllib/linalg/SparseVector.html). We recommend using the factory methods implemented in -[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vector.html) to create local vectors. +[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. {% highlight java %} import org.apache.spark.mllib.linalg.Vector; @@ -84,7 +84,7 @@ and the following as sparse vectors: with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. {% highlight python %} import numpy as np @@ -241,7 +241,7 @@ The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). We recommend using the factory methods implemented -in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local +in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local matrices. 
{% highlight scala %} From 1fd33815f47478f5f2e8b55b90757819b8cb5247 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 16 May 2015 08:18:41 +0100 Subject: [PATCH 075/109] [SPARK-4556] [BUILD] binary distribution assembly can't run in local mode Add note on building a runnable distribution with make-distribution.sh Author: Sean Owen Closes #6186 from srowen/SPARK-4556 and squashes the following commits: 4002966 [Sean Owen] Add pointer to --help flag 9fa7883 [Sean Owen] Add note on building a runnable distribution with make-distribution.sh --- docs/building-spark.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index 6e310ff424784..4dbccb9e6e46c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -34,6 +34,16 @@ and in `project/SparkBuild.scala` add: to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/pull/2883/files) if you are unsure of where to add these lines. +# Building a Runnable Distribution + +To create a Spark distribution like those distributed by the +[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +to be runnable, use `make-distribution.sh` in the project root directory. It can be configured +with Maven profile settings and so on like the direct Maven build. Example: + + ./make-distribution.sh --name custom-spark --tgz -Phadoop-2.4 -Pyarn + +For more information on usage, run `./make-distribution.sh --help` # Setting up Maven's Memory Usage From 0ac8b01a07840f199bbc79fb845762284aead6de Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sat, 16 May 2015 08:24:21 +0100 Subject: [PATCH 076/109] [SPARK-7672] [CORE] Use int conversion in translating kryoserializer.buffer.mb to kryoserializer.buffer In translating spark.kryoserializer.buffer.mb to spark.kryoserializer.buffer, use of toDouble will lead to "Fractional values not supported" error even when spark.kryoserializer.buffer.mb is an integer. 
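To make the failure mode concrete, here is a minimal standalone sketch of the two translation functions (the object and value names are illustrative, not Spark internals): interpolating a Double always yields a fractional size string such as "1000.0k", which the size parser then rejects, while truncating to Int first keeps the value integral.

    object KryoBufferTranslationSketch {
      // Old behaviour: a Double is interpolated, so even an integral setting like "1"
      // becomes "1000.0k" and fails to parse as a size.
      val oldTranslation: String => String = s => s"${s.toDouble * 1000}k"

      // Fixed behaviour (mirrors the patch): truncate to Int, so "1" -> "1000k"
      // and "1.1" -> "1100k".
      val newTranslation: String => String = s => s"${(s.toDouble * 1000).toInt}k"

      def main(args: Array[String]): Unit = {
        Seq("1", "1.1").foreach { mb =>
          println(s"$mb MB -> old: ${oldTranslation(mb)}, fixed: ${newTranslation(mb)}")
        }
      }
    }

The regression test added below exercises exactly this path: setting the deprecated spark.kryoserializer.buffer.mb to "1.1" must resolve to 1100 KB under the new key.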
ilganeli, andrewor14 Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #6198 from nishkamravi2/master_nravi and squashes the following commits: 171a53c [nishkamravi2] Update SparkConfSuite.scala 5261bf6 [Nishkam Ravi] Add a test for deprecated config spark.kryoserializer.buffer.mb 5190f79 [Nishkam Ravi] In translating from deprecated spark.kryoserializer.buffer.mb to spark.kryoserializer.buffer use int conversion since fractions are not permissible 059ce82 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi eaa13b5 [nishkamravi2] Update Client.scala 981afd2 [Nishkam Ravi] Check for read permission before initiating copy 1b81383 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 0f1abd0 [nishkamravi2] Update Utils.scala 474e3bf [nishkamravi2] Update DiskBlockManager.scala 97c383e [nishkamravi2] Update Utils.scala 8691e0c [Nishkam Ravi] Add a try/catch block around Utils.removeShutdownHook 2be1e76 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 1c13b79 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 
71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- core/src/main/scala/org/apache/spark/SparkConf.scala | 2 +- core/src/test/scala/org/apache/spark/SparkConfSuite.scala | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a8fc90ad2050e..b5e5d6f1465f3 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -509,7 +509,7 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${s.toDouble * 1000}k")), + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala 
b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 68d08e32f9aa4..fafa4ed606b08 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -241,6 +241,9 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf.set("spark.yarn.applicationMaster.waitTries", "42") assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) + + conf.set("spark.kryoserializer.buffer.mb", "1.1") + assert(conf.getSizeAsKb("spark.kryoserializer.buffer") === 1100) } test("akka deprecated configs") { From 47e7ffe36b8a8a246fe9af522aff480d19c0c8a6 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 16 May 2015 00:44:29 -0700 Subject: [PATCH 077/109] [SPARK-7655][Core][SQL] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin' Because both `AkkaRpcEndpointRef.ask` and `BroadcastHashJoin` uses `scala.concurrent.ExecutionContext.Implicits.global`. However, because the tasks in `BroadcastHashJoin` are usually long-running tasks, which will occupy all threads in `global`. Then `ask` cannot get a chance to process the replies. For `ask`, actually the tasks are very simple, so we can use `MoreExecutors.sameThreadExecutor()`. For `BroadcastHashJoin`, it's better to use `ThreadUtils.newDaemonCachedThreadPool`. Author: zsxwing Closes #6200 from zsxwing/SPARK-7655-2 and squashes the following commits: cfdc605 [zsxwing] Remove redundant imort and minor doc fix cf83153 [zsxwing] Add "sameThread" and "newDaemonCachedThreadPool with maxThreadNumber" to ThreadUtils 08ad0ee [zsxwing] Remove 'scala.concurrent.ExecutionContext.Implicits.global' in 'ask' and 'BroadcastHashJoin' --- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 ++++--- .../org/apache/spark/util/ThreadUtils.scala | 24 ++++++++++++++++++- .../apache/spark/util/ThreadUtilsSuite.scala | 12 ++++++++++ .../execution/joins/BroadcastHashJoin.scala | 10 ++++++-- 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ba0d468f111ef..0161962cde073 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -29,9 +29,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import com.google.common.util.concurrent.MoreExecutors + import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} /** * A RpcEnv implementation based on Akka. @@ -294,8 +296,8 @@ private[akka] class AkkaRpcEndpointRef( } override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + // The function will run in the calling thread, so it should be short and never block. 
case msg @ AkkaMessage(message, reply) => if (reply) { logError(s"Receive $msg but the sender cannot reply") @@ -305,7 +307,7 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }.mapTo[T] + }(ThreadUtils.sameThread).mapTo[T] } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 098a4b79496b2..ca5624a3d8b3d 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -20,10 +20,22 @@ package org.apache.spark.util import java.util.concurrent._ -import com.google.common.util.concurrent.ThreadFactoryBuilder +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} + +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} private[spark] object ThreadUtils { + private val sameThreadExecutionContext = + ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor()) + + /** + * An `ExecutionContextExecutor` that runs each task in the thread that invokes `execute/submit`. + * The caller should make sure the tasks running in this `ExecutionContextExecutor` are short and + * never block. + */ + def sameThread: ExecutionContextExecutor = sameThreadExecutionContext + /** * Create a thread factory that names threads with a prefix and also sets the threads to daemon. */ @@ -40,6 +52,16 @@ private[spark] object ThreadUtils { Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] } + /** + * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names + * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. + */ + def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + new ThreadPoolExecutor( + 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + } + /** * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a * unique, sequentially assigned integer. 
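Since ThreadUtils is private[spark], the snippet below is only a rough, self-contained sketch of the same two helpers built directly on Guava (all names are illustrative). It shows the intended usage split: a same-thread context for trivial, non-blocking callbacks such as the ask reply handling, and a bounded daemon cached pool for long-running work like the broadcast in BroadcastHashJoin.

    import java.util.concurrent.{SynchronousQueue, ThreadPoolExecutor, TimeUnit}
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration._
    import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}

    object SameThreadSketch {
      // Callbacks run on whichever thread completes the Future, so they must stay short
      // and never block.
      val sameThread: ExecutionContext =
        ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor())

      // A cached pool capped at maxThreads, with named daemon threads.
      def boundedCachedPool(prefix: String, maxThreads: Int): ThreadPoolExecutor = {
        val factory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()
        new ThreadPoolExecutor(0, maxThreads, 60L, TimeUnit.SECONDS,
          new SynchronousQueue[Runnable], factory)
      }

      def main(args: Array[String]): Unit = {
        // With a same-thread context the body executes on the caller, so this prints
        // the calling thread's name (typically "main").
        val f = Future { Thread.currentThread().getName }(sameThread)
        println(Await.result(f, 10.seconds))
      }
    }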
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index a3aa3e953fbec..751d3df9cc8f7 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ + import org.scalatest.FunSuite class ThreadUtilsSuite extends FunSuite { @@ -54,4 +57,13 @@ class ThreadUtilsSuite extends FunSuite { executor.shutdownNow() } } + + test("sameThread") { + val callerThreadName = Thread.currentThread().getName() + val f = Future { + Thread.currentThread().getName() + }(ThreadUtils.sameThread) + val futureThreadName = Await.result(f, 10.seconds) + assert(futureThreadName === callerThreadName) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 05dd5681edfac..fe43fc4125c8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD +import org.apache.spark.util.ThreadUtils import scala.concurrent._ import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Row, Expression} @@ -64,7 +64,7 @@ case class BroadcastHashJoin( val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) - } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) protected override def doExecute(): RDD[Row] = { val broadcastRelation = Await.result(broadcastFuture, timeout) @@ -74,3 +74,9 @@ case class BroadcastHashJoin( } } } + +object BroadcastHashJoin { + + private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 1024)) +} From ce6391296a061bc352386080a2ee96bb63fcc4ac Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 20:55:10 +0800 Subject: [PATCH 078/109] [HOTFIX] [SQL] Fixes DataFrameWriter.mode(String) We forgot an assignment there. 
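In other words, the match expression in mode(String) computed a SaveMode but its result was never stored, so the requested mode was silently dropped. A distilled sketch of the bug pattern (an illustrative class, not the real DataFrameWriter):

    class ModeSketch {
      private var mode: String = "error"

      // Buggy: the match result is computed and then discarded, so `mode` stays "error".
      def modeBuggy(saveMode: String): this.type = {
        saveMode.toLowerCase match {
          case m @ ("overwrite" | "append" | "ignore" | "error") => m
          case other => throw new IllegalArgumentException(s"Unknown save mode: $other")
        }
        this
      }

      // Fixed: assign the match result, mirroring the `this.mode =` added by the hotfix.
      def modeFixed(saveMode: String): this.type = {
        this.mode = saveMode.toLowerCase match {
          case m @ ("overwrite" | "append" | "ignore" | "error") => m
          case other => throw new IllegalArgumentException(s"Unknown save mode: $other")
        }
        this
      }

      def currentMode: String = mode
    }

The added test below triggers the user-visible symptom: without the assignment, write.mode("overwrite").save(path) still behaved as the default error-if-exists mode when the path already existed.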
/cc rxin Author: Cheng Lian Closes #6212 from liancheng/fix-df-writer and squashes the following commits: 711fbb0 [Cheng Lian] Adds a test case 3b72d78 [Cheng Lian] Fixes DataFrameWriter.mode(String) --- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../scala/org/apache/spark/sql/sources/SaveLoadSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b1fc18ac3cb54..9f42f0f1f4398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -55,7 +55,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter = { - saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 7a28e9af3673c..274c652dd14d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -75,6 +75,13 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { checkLoad() } + test("save with string mode and path, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + path.createNewFile() + df.write.mode("overwrite").save(path.toString) + checkLoad() + } + test("save with path and datasource, and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") df.write.json(path.toString) From 1b4e710e5cdb00febb4c5920d81e77c2e3966a8b Mon Sep 17 00:00:00 2001 From: Matthew Brandyberry Date: Sat, 16 May 2015 18:17:48 +0100 Subject: [PATCH 079/109] [BUILD] update jblas dependency version to 1.2.4 jblas 1.2.4 includes native library support for PPC64LE. Author: Matthew Brandyberry Closes #6199 from mtbrandy/jblas-1.2.4 and squashes the following commits: 9df9301 [Matthew Brandyberry] [BUILD] update jblas dependency version to 1.2.4 --- LICENSE | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LICENSE b/LICENSE index d6b9ccf07d999..9d1b00beff748 100644 --- a/LICENSE +++ b/LICENSE @@ -861,7 +861,7 @@ The following components are provided under a BSD-style license. 
See project lin (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) + (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/pom.xml b/pom.xml index 86aa0a9fa134c..1b45cdb67012a 100644 --- a/pom.xml +++ b/pom.xml @@ -137,7 +137,7 @@ 0.13.1 10.10.1.1 1.6.0rc3 - 1.2.3 + 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 0.5.0 From 161d0b4a41f453b21adde46a86e16c2743752799 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 16 May 2015 15:03:57 -0700 Subject: [PATCH 080/109] [SPARK-7654][MLlib] Migrate MLlib to the DataFrame reader/writer API. Author: Reynold Xin Closes #6211 from rxin/mllib-reader and squashes the following commits: 79a2cb9 [Reynold Xin] [SPARK-7654][MLlib] Migrate MLlib to the DataFrame reader/writer API. --- .../org/apache/spark/examples/mllib/DatasetExample.scala | 2 +- .../scala/org/apache/spark/examples/sql/RDDRelation.scala | 2 +- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 4 ++-- .../mllib/classification/impl/GLMClassificationModel.scala | 2 +- .../apache/spark/mllib/clustering/GaussianMixtureModel.scala | 2 +- .../scala/org/apache/spark/mllib/clustering/KMeansModel.scala | 2 +- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 2 +- .../spark/mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../apache/spark/mllib/regression/IsotonicRegression.scala | 2 +- .../spark/mllib/regression/impl/GLMRegressionModel.scala | 2 +- .../org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- .../apache/spark/mllib/tree/model/treeEnsembleModels.scala | 2 +- 13 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index c95cca7d656e8..520893b26d595 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -103,7 +103,7 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - df.saveAsParquetFile(outputDir) + df.write.parquet(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.read.parquet(outputDir) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index acc89199d5849..b11e32047dc34 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -58,7 +58,7 @@ object RDDRelation { df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - df.saveAsParquetFile("pair.parquet") + df.write.parquet("pair.parquet") // Read in parquet file. 
Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.read.parquet("pair.parquet") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index af24ab616663b..ac0ebeceaa1df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -140,7 +140,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { @@ -186,7 +186,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 3b6790cce47c6..d842ec57b2f52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel { // Create Parquet data. val data = Data(weights, intercept, threshold) - sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index c22862c130e77..731b43a1be574 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -126,7 +126,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index ba228b11fcec3..252e166e85cef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -110,7 +110,7 @@ object KMeansModel extends Loader[KMeansModel] { val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => Cluster(id, point) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): KMeansModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index aa53e88d59856..1ed01c9d8ba0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -74,7 +74,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = model.assignments.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { @@ -86,7 +86,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) val assignmentsRDD = assignments.map { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 98e83112f52ae..731f7576c2335 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -580,7 +580,7 @@ object Word2VecModel extends Loader[Word2VecModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 88c2148403313..b960fbc5bf5f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -281,8 +281,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) - model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) - model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + model.userFeatures.toDF("id", "features").write.parquet(userPath(path)) + model.productFeatures.toDF("id", "features").write.parquet(productPath(path)) } def load(sc: SparkContext, path: String): MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 4ce541ae5bed9..22b9b22a871f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -184,7 +184,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { sqlContext.createDataFrame( boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } - 
).saveAsParquetFile(dataPath(path)) + ).write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index b55944f74f623..2aa0e9ef96d48 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -60,7 +60,7 @@ private[regression] object GLMRegressionModel { val data = Data(weights, intercept) val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 331af428533de..a558f84c8d506 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -223,7 +223,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD: DataFrame = sc.parallelize(nodes) .map(NodeData.apply(0, _)) .toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 8341219bfa71c..f9cd0140fe63f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -414,7 +414,7 @@ private[tree] object TreeEnsembleModel extends Logging { val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** From 3b6ef2c5391b528ef989e24400fbb0c496c3b245 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 16 May 2015 21:03:22 -0700 Subject: [PATCH 081/109] [SPARK-7655][Core] Deserializing value should not hold the TaskSchedulerImpl lock We should not call `DirectTaskResult.value` when holding the `TaskSchedulerImpl` lock. It may cost dozens of seconds to deserialize a large object. 
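The fix boils down to memoizing the deserialized object so that the expensive call happens once, on the result-getter thread, before any scheduler lock is taken. A stripped-down sketch of that pattern (illustrative names, not the actual DirectTaskResult):

    class LazyResult[T](bytes: Array[Byte], deserialize: Array[Byte] => T) {
      private var deserialized = false
      private var obj: T = _

      // Expensive only on the first call; later calls (for example while the
      // scheduler lock is held) just return the cached object.
      def value(): T = {
        if (!deserialized) {
          obj = deserialize(bytes)   // may take seconds for a large payload
          deserialized = true
        }
        obj
      }
    }

As the diff below spells out, TaskResultGetter.enqueueSuccessfulTask calls value() once up front, so the later call inside TaskSetManager.handleSuccessfulTask, made while the TaskSchedulerImpl lock is held, only returns the cached value.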
Author: zsxwing Closes #6195 from zsxwing/SPARK-7655 and squashes the following commits: 21f502e [zsxwing] Add more comments e25fa88 [zsxwing] Add comments 15010b5 [zsxwing] Deserialize value should not hold the TaskSchedulerImpl lock --- .../apache/spark/scheduler/TaskResult.scala | 23 +++++++++++++++++-- .../spark/scheduler/TaskResultGetter.scala | 4 ++++ .../spark/scheduler/TaskSetManager.scala | 6 +++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 1f114a0207f7b..8b2a742b96988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -40,6 +40,9 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long var metrics: TaskMetrics) extends TaskResult[T] with Externalizable { + private var valueObjectDeserialized = false + private var valueObject: T = _ + def this() = this(null.asInstanceOf[ByteBuffer], null, null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { @@ -72,10 +75,26 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long } } metrics = in.readObject().asInstanceOf[TaskMetrics] + valueObjectDeserialized = false } + /** + * When `value()` is called at the first time, it needs to deserialize `valueObject` from + * `valueBytes`. It may cost dozens of seconds for a large instance. So when calling `value` at + * the first time, the caller should avoid to block other threads. + * + * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. + */ def value(): T = { - val resultSer = SparkEnv.get.serializer.newInstance() - resultSer.deserialize(valueBytes) + if (valueObjectDeserialized) { + valueObject + } else { + // This should not run when holding a lock because it may cost dozens of seconds for a large + // value. + val resultSer = SparkEnv.get.serializer.newInstance() + valueObject = resultSer.deserialize(valueBytes) + valueObjectDeserialized = true + valueObject + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 391827c1d2156..46a6f6537e2ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -54,6 +54,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { return } + // deserialize "value" without holding any lock so that it won't block other threads. + // We should call it here, so that when it's called again in + // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. 
+ directResult.value() (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dc325283d961..c4487d5b37247 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -620,6 +620,12 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { From 517eb37a85e0a28820bcfd5d98c50d02df6521c6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 16 May 2015 22:01:53 -0700 Subject: [PATCH 082/109] [SPARK-7654][SQL] Move JDBC into DataFrame's reader/writer interface. Also moved all the deprecated functions into one place for SQLContext and DataFrame, and updated tests to use the new API. Author: Reynold Xin Closes #6210 from rxin/df-writer-reader-jdbc and squashes the following commits: 7465c2c [Reynold Xin] Fixed unit test. 118e609 [Reynold Xin] Updated tests. 3441b57 [Reynold Xin] Updated javadoc. 13cdd1c [Reynold Xin] [SPARK-7654][SQL] Move JDBC into DataFrame's reader/writer interface. 
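For reference, a hedged usage sketch of the relocated JDBC entry points (the local master, H2 URL, and table name are placeholders and assume the corresponding JDBC driver is on the classpath):

    import java.util.Properties

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{SQLContext, SaveMode}

    object JdbcReaderWriterSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("jdbc-sketch").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val df = sc.parallelize(Seq((1, "alice"), (2, "bob"))).toDF("id", "name")
        val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"
        val props = new Properties()

        // Replaces the deprecated createJDBCTable / insertIntoJDBC calls.
        df.write.mode(SaveMode.Overwrite).jdbc(url, "people", props)

        // Replaces the deprecated SQLContext JDBC loaders.
        val people = sqlContext.read.jdbc(url, "people", props)
        people.show()

        sc.stop()
      }
    }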
--- .../spark/examples/sql/JavaSparkSQL.java | 4 +- .../org/apache/spark/sql/DataFrame.scala | 284 +++----- .../apache/spark/sql/DataFrameReader.scala | 89 ++- .../apache/spark/sql/DataFrameWriter.scala | 53 +- .../org/apache/spark/sql/SQLContext.scala | 682 +++++++----------- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 30 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 16 +- .../org/apache/spark/sql/jdbc/JdbcUtils.scala | 52 ++ .../org/apache/spark/sql/jdbc/jdbc.scala | 6 +- .../spark/sql/JavaApplySchemaSuite.java | 4 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 10 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 31 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 54 +- .../hive/JavaMetastoreDataSourcesSuite.java | 20 +- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 73 +- .../hive/execution/HiveResolutionSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 8 +- .../apache/spark/sql/hive/parquetSuites.scala | 14 +- .../sql/sources/hadoopFsRelationSuites.scala | 68 +- 20 files changed, 747 insertions(+), 761 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 173633ce059e3..afee279ec32b1 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -94,7 +94,7 @@ public String call(Row row) { System.out.println("=== Data source: Parquet File ==="); // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.saveAsParquetFile("people.parquet"); + schemaPeople.write().parquet("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. 
peopleFromJsonRDD.printSchema(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 55ef357a99f71..27e9af49f0664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConversions._ @@ -40,9 +39,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JacksonGenerator -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.sources.CreateTableUsingAsSelect import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -227,10 +225,6 @@ class DataFrame private[sql]( } } - /** Left here for backward compatibility. */ - @deprecated("1.3.0", "use toDF") - def toSchemaRDD: DataFrame = this - /** * Returns the object itself. * @group basic @@ -1299,12 +1293,119 @@ class DataFrame private[sql]( @Experimental def write: DataFrameWriter = new DataFrameWriter(this) + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + * @group output + * @since 1.3.0 + */ + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table. + * Throws an exception if the table already exists. + * @group output + * @since 1.3.0 + */ + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + /** + * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + * @group rdd + * @since 1.3.0 + */ + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JacksonGenerator(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + + /** + * Converts a JavaRDD to a PythonRDD. 
+ */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } + + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + /** Left here for backward compatibility. */ + @deprecated("use toDF", "1.3.0") + def toSchemaRDD: DataFrame = this + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + */ + @deprecated("Use write.jdbc()", "1.4.0") + def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { + val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + */ + @deprecated("Use write.jdbc()", "1.4.0") + def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + val w = if (overwrite) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @since 1.3.0 */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { @@ -1328,7 +1429,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 */ @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { @@ -1347,7 +1447,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 */ @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { @@ -1373,7 +1472,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. 
* @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { @@ -1393,7 +1491,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { @@ -1412,7 +1509,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1437,7 +1533,6 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", "1.4.0") @@ -1454,7 +1549,6 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 */ @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { @@ -1465,7 +1559,6 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output - * @since 1.3.0 */ @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { @@ -1476,7 +1569,6 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { @@ -1487,7 +1579,6 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { @@ -1498,7 +1589,6 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1513,7 +1603,6 @@ class DataFrame private[sql]( * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output - * @since 1.3.0 */ @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( @@ -1523,163 +1612,10 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. 
- * @group output - * @since 1.3.0 - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd - } - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table. - * Throws an exception if the table already exists. - * @group output - * @since 1.3.0 - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * Returns the content of the [[DataFrame]] as a RDD of JSON strings. - * @group rdd - * @since 1.3.0 - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val writer = new CharArrayWriter() - // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - new Iterator[String] { - override def hasNext: Boolean = iter.hasNext - override def next(): String = { - JacksonGenerator(rowSchema, gen)(iter.next()) - gen.flush() - - val json = writer.toString - if (hasNext) { - writer.reset() - } else { - gen.close() - } - - json - } - } - } - } - //////////////////////////////////////////////////////////////////////////// - // JDBC Write Support //////////////////////////////////////////////////////////////////////////// - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.3.0 - */ - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - createJDBCTable(url, table, allowExisting, new Properties()) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` - * using connection properties defined in `properties`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.4.0 - */ - def createJDBCTable( - url: String, - table: String, - allowExisting: Boolean, - properties: Properties): Unit = { - val conn = DriverManager.getConnection(url, properties) - try { - if (allowExisting) { - val sql = s"DROP TABLE IF EXISTS $table" - conn.prepareStatement(sql).executeUpdate() - } - val schema = JDBCWriteDetails.schemaString(this, url) - val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - JDBCWriteDetails.saveTable(this, url, table, properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. 
- * @group output
- * @since 1.3.0
- */
- def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
- insertIntoJDBC(url, table, overwrite, new Properties())
- }
-
- /**
- * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`
- * using connection properties defined in `properties`.
- * Assumes the table already exists and has a compatible schema. If you
- * pass `true` for `overwrite`, it will `TRUNCATE` the table before
- * performing the `INSERT`s.
- *
- * The table must already exist on the database. It must have a schema
- * that is compatible with the schema of this RDD; inserting the rows of
- * the RDD in order via the simple statement
- * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
- * @group output
- * @since 1.4.0
- */
- def insertIntoJDBC(
- url: String,
- table: String,
- overwrite: Boolean,
- properties: Properties): Unit = {
- if (overwrite) {
- val conn = DriverManager.getConnection(url, properties)
- try {
- val sql = s"TRUNCATE TABLE $table"
- conn.prepareStatement(sql).executeUpdate()
- } finally {
- conn.close()
- }
- }
- JDBCWriteDetails.saveTable(this, url, table, properties)
- }
+ // End of deprecated methods
 ////////////////////////////////////////////////////////////////////////////
- // for Python API
 ////////////////////////////////////////////////////////////////////////////
- /**
- * Converts a JavaRDD to a PythonRDD.
- */
- protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 4d63faad6fb7c..381c10f48f3c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,16 @@ package org.apache.spark.sql
+import java.util.Properties
+
 import org.apache.hadoop.fs.Path
+import org.apache.spark.Partition
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
 import org.apache.spark.sql.json.{JsonRDD, JSONRelation}
 import org.apache.spark.sql.parquet.ParquetRelation2
 import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource}
@@ -31,7 +35,7 @@ import org.apache.spark.sql.types.StructType
 /**
 * :: Experimental ::
 * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems,
- * key-value stores, etc).
+ * key-value stores, etc). Use [[SQLContext.read]] to access this.
 *
 * @since 1.4.0
 */
@@ -94,6 +98,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) {
 * Specifies the input partitioning. If specified, the underlying data source does not need to
 * discover the data partitioning scheme, and thus can speed up very large inputs.
 *
+ * This is only applicable for Parquet at the moment.
+ *
 * @since 1.4.0
 */
 @scala.annotation.varargs
@@ -128,6 +134,87 @@ class DataFrameReader private[sql](sqlContext: SQLContext) {
 DataFrame(sqlContext, LogicalRelation(resolved.relation))
 }
 
+ /**
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table and connection properties.
+ * + * @since 1.4.0 + */ + def jdbc(url: String, table: String, properties: Properties): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null), properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + connectionProperties: Properties): DataFrame = { + val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) + val parts = JDBCRelation.columnPartition(partitioning) + jdbc(url, table, parts, connectionProperties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table using connection properties. The `predicates` parameter gives a list + * expressions suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param predicates Condition in the where clause for each partition. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + predicates: Array[String], + connectionProperties: Properties): DataFrame = { + val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) => + JDBCPartition(part, i) : Partition + } + jdbc(url, table, parts, connectionProperties) + } + + private def jdbc( + url: String, + table: String, + parts: Array[Partition], + connectionProperties: Properties): DataFrame = { + val relation = JDBCRelation(url, table, parts, connectionProperties)(sqlContext) + sqlContext.baseRelationToDataFrame(relation) + } + /** * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. 
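For illustration only, and not part of the patch: a minimal Scala sketch of how the three reader-side jdbc() overloads added above can be called. The H2 URL, table name, and THEID column mirror the test fixtures used later in this series and are placeholders; a local SparkContext/SQLContext is assumed.

import java.util.Properties

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Placeholder setup; assumes the H2 driver is on the classpath and TEST.PEOPLE exists.
val sc = new SparkContext(new SparkConf().setAppName("jdbc-read-sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)
val url = "jdbc:h2:mem:testdb;user=testUser;password=testPass"
val props = new Properties()

// 1. Whole-table read through a single partition.
val all = sqlContext.read.jdbc(url, "TEST.PEOPLE", props)

// 2. Range-partitioned read: THEID values in [0, 4) are split across 3 partitions.
val byColumn = sqlContext.read.jdbc(url, "TEST.PEOPLE",
  columnName = "THEID", lowerBound = 0L, upperBound = 4L, numPartitions = 3,
  connectionProperties = props)

// 3. Predicate-partitioned read: each WHERE-clause string becomes one partition.
val byPredicates = sqlContext.read.jdbc(url, "TEST.PEOPLE",
  Array("THEID < 2", "THEID >= 2"), props)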
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9f42f0f1f4398..f2e721d4db271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import java.util.Properties + import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc). + * key-value stores, etc). Use [[DataFrame.write]] to access this. * * @since 1.4.0 */ @@ -110,6 +113,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. * + * This is only applicable for Parquet at the moment. + * * @since 1.4.0 */ @scala.annotation.varargs @@ -161,6 +166,52 @@ final class DataFrameWriter private[sql](df: DataFrame) { df.sqlContext.executePlan(cmd).toRdd } + /** + * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the + * table already exists in the external database, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + */ + def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { + val conn = JdbcUtils.createConnection(url, connectionProperties) + + try { + var tableExists = JdbcUtils.tableExists(conn, table) + + if (mode == SaveMode.Ignore && tableExists) { + return + } + + if (mode == SaveMode.ErrorIfExists && tableExists) { + sys.error(s"Table $table already exists.") + } + + if (mode == SaveMode.Overwrite && tableExists) { + JdbcUtils.dropTable(conn, table) + tableExists = false + } + + // Create the table if the table didn't exist. + if (!tableExists) { + val schema = JDBCWriteDetails.schemaString(df, url) + val sql = s"CREATE TABLE $table ($schema)" + conn.prepareStatement(sql).executeUpdate() + } + } finally { + conn.close() + } + + JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + } + /** * Saves the content of the [[DataFrame]] in JSON format at the specified path. 
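Again for illustration only (not part of the patch): the matching writer-side call, assuming df is an existing DataFrame and the same placeholder H2 URL. The mode() setting selects the drop/create/append behaviour implemented in the jdbc() method above.

import java.util.Properties

import org.apache.spark.sql.SaveMode

val url = "jdbc:h2:mem:testdb;user=testUser;password=testPass"  // placeholder
val props = new Properties()

// Default mode is ErrorIfExists: creates TEST.PEOPLE_COPY, or fails if it is already there.
df.write.jdbc(url, "TEST.PEOPLE_COPY", props)

// Append keeps the existing table and only inserts the new rows.
df.write.mode(SaveMode.Append).jdbc(url, "TEST.PEOPLE_COPY", props)

// Overwrite drops the table, recreates it from df's schema, and reloads it.
df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.PEOPLE_COPY", props)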
* This is equivalent to: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 34a50e522c4ca..ac1a800219423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import com.google.common.reflect.TypeToken +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -40,11 +41,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -import org.apache.spark.{Partition, SparkContext} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -531,67 +530,6 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rdd.rdd, beanClass) } - /** - * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * Example: - * {{{ - * import org.apache.spark.sql._ - * import org.apache.spark.sql.types._ - * val sqlContext = new org.apache.spark.sql.SQLContext(sc) - * - * val schema = - * StructType( - * StructField("name", StringType, false) :: - * StructField("age", IntegerType, true) :: Nil) - * - * val people = - * sc.textFile("examples/src/main/resources/people.txt").map( - * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sqlContext. applySchema(people, schema) - * dataFrame.printSchema - * // root - * // |-- name: string (nullable = false) - * // |-- age: integer (nullable = true) - * - * dataFrame.registerTempTable("people") - * sqlContext.sql("select name from people").collect.foreach(println) - * }}} - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - /** * :: Experimental :: * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. 
@@ -606,205 +544,6 @@ class SQLContext(@transient val sparkContext: SparkContext) @Experimental def read: DataFrameReader = new DataFrameReader(this) - /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.parquet()", "1.4.0") - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else if (conf.parquetUseDataSourceApi) { - read.parquet(paths : _*) - } else { - DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) - } - } - - /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String): DataFrame = { - read.json(path) - } - - /** - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String, schema: StructType): DataFrame = { - read.schema(schema).json(path) - } - - /** - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(path) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. 
- * - * @group specificdata - * @since 1.3.0 - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. - * - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.load(path)", "1.4.0") - def load(path: String): DataFrame = { - read.load(path) - } - - /** - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.format(source).load(path)", "1.4.0") - def load(path: String, source: String): DataFrame = { - read.format(source).load(path) - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") - def load(source: String, options: Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") - def load( - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") - def load( - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - /** * :: Experimental :: * Creates an external table from the given path and returns the corresponding DataFrame. @@ -903,150 +642,24 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group ddl_ops * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableName) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. 
- * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table and connection properties. - * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @param properties connection properties - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int, - properties: Properties): DataFrame = { - val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) - val parts = JDBCRelation.columnPartition(partitioning) - jdbc(url, table, parts, properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - jdbc(url, table, theParts, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table using connection properties. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. 
- * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - theParts: Array[String], - properties: Properties): DataFrame = { - val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => - JDBCPartition(part, i) : Partition - } - jdbc(url, table, parts, properties) - } - - private def jdbc( - url: String, - table: String, - parts: Array[Partition], - properties: Properties): DataFrame = { - val relation = JDBCRelation(url, table, parts, properties)(this) - baseRelationToDataFrame(relation) + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) } /** @@ -1372,6 +985,263 @@ class SQLContext(@transient val sparkContext: SparkContext) } } + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @group specificdata + */ + @deprecated("Use read.parquet()", "1.4.0") + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else if (conf.parquetUseDataSourceApi) { + read.parquet(paths : _*) + } else { + DataFrame(this, parquet.ParquetRelation( + paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } + } + + /** + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String): DataFrame = { + read.json(path) + } + + /** + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, schema: StructType): DataFrame = { + read.schema(schema).json(path) + } + + /** + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(path) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. 
+ * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + */ + @deprecated("Use read.load(path)", "1.4.0") + def load(path: String): DataFrame = { + read.load(path) + } + + /** + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + */ + @deprecated("Use read.format(source).load(path)", "1.4.0") + def load(path: String, source: String): DataFrame = { + read.format(source).load(path) + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. 
+ *
+ * @group genericdata
+ */
+ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
+ def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame =
+ {
+ read.format(source).schema(schema).options(options).load()
+ }
+
+ /**
+ * (Scala-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
+ *
+ * @group genericdata
+ */
+ @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
+ def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = {
+ read.format(source).schema(schema).options(options).load()
+ }
+
+ /**
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table.
+ *
+ * @group specificdata
+ */
+ @deprecated("use read.jdbc()", "1.4.0")
+ def jdbc(url: String, table: String): DataFrame = {
+ read.jdbc(url, table, new Properties)
+ }
+
+ /**
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table. Partitions of the table will be retrieved in parallel based on the parameters
+ * passed to this function.
+ *
+ * @param columnName the name of a column of integral type that will be used for partitioning.
+ * @param lowerBound the minimum value of `columnName` used to decide partition stride
+ * @param upperBound the maximum value of `columnName` used to decide partition stride
+ * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split
+ * evenly into this many partitions
+ * @group specificdata
+ */
+ @deprecated("use read.jdbc()", "1.4.0")
+ def jdbc(
+ url: String,
+ table: String,
+ columnName: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): DataFrame = {
+ read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties)
+ }
+
+ /**
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table. The theParts parameter gives a list expressions
+ * suitable for inclusion in WHERE clauses; each one defines one partition
+ * of the [[DataFrame]].
+ *
+ * @group specificdata
+ */
+ @deprecated("use read.jdbc()", "1.4.0")
+ def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = {
+ read.jdbc(url, table, theParts, new Properties)
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+ ////////////////////////////////////////////////////////////////////////////
+ // End of deprecated methods
+ ////////////////////////////////////////////////////////////////////////////
+ ////////////////////////////////////////////////////////////////////////////
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 40483d3ec7701..95935ba874a72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -29,7 +29,16 @@ import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources._
 
+/**
+ * Data corresponding to one partition of a JDBCRDD.
+ */
+private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
+ override def index: Int = idx
+}
+
+
 private[sql] object JDBCRDD extends Logging {
+
 /**
 * Maps a JDBC type to a Catalyst type.
This function is called only when * the DriverQuirks class corresponding to your database driver returns null. @@ -168,6 +177,7 @@ private[sql] object JDBCRDD extends Logging { DriverManager.getConnection(url, properties) } } + /** * Build and return JDBCRDD from the given information. * @@ -193,18 +203,14 @@ private[sql] object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { - - val prunedSchema = pruneSchema(schema, requiredColumns) - - return new - JDBCRDD( - sc, - getConnector(driver, url, properties), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) + new JDBCRDD( + sc, + getConnector(driver, url, properties), + pruneSchema(schema, requiredColumns), + fqTable, + requiredColumns, + filters, + parts) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 93e82549f213b..09d6865457df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,26 +17,16 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager import java.util.Properties import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - -/** - * Data corresponding to one partition of a JDBCRDD. - */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { - override def index: Int = idx -} /** * Instructions on how to partition the table among workers. @@ -152,6 +142,8 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - data.insertIntoJDBC(url, table, overwrite, properties) + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(url, table, properties) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala new file mode 100644 index 0000000000000..cc918c237192b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import scala.util.Try + +/** + * Util functions for JDBC tables. 
+ */ +private[sql] object JdbcUtils { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index c099881a01226..a61790b8472c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -163,8 +163,8 @@ package object jdbc { table: String, properties: Properties = new Properties()) { val quirks = DriverQuirks.get(url) - var nullTypes: Array[Int] = df.schema.fields.map(field => { - var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 + val nullTypes: Array[Int] = df.schema.fields.map { field => + val nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 if (nullType.isEmpty) { field.dataType match { case IntegerType => java.sql.Types.INTEGER @@ -183,7 +183,7 @@ package object jdbc { s"Can't translate null value for field $field") } } else nullType.get - }).toArray + } val rddSchema = df.schema df.foreachPartition { iterator => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index c344a9b095c52..fcb8f5499cf84 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -187,14 +187,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = sqlContext.jsonRDD(jsonRDD); + DataFrame df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); + DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6a0bcefe7aa88..2706e01bd28af 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -67,7 +67,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); 
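A standalone sketch, not part of the patch, of the existence probe that the JdbcUtils helpers introduced above rely on: run a cheap query and treat any failure as "table absent". JdbcUtils itself is private[sql], so this restates the technique with plain JDBC; the H2 URL, table name, and schema are placeholders, and the H2 driver is assumed to be on the classpath.

import java.sql.DriverManager
import java.util.Properties

import scala.util.Try

val conn = DriverManager.getConnection("jdbc:h2:mem:sketchdb", new Properties())
try {
  // Same probe as JdbcUtils.tableExists: if the query succeeds, the table exists.
  val exists = Try(
    conn.prepareStatement("SELECT 1 FROM TEST.PEOPLE LIMIT 1").executeQuery().next()).isSuccess
  if (!exists) {
    // Mirrors the create-if-missing branch in DataFrameWriter.jdbc, with a hand-written schema.
    conn.prepareStatement(
      "CREATE TABLE TEST.PEOPLE (NAME VARCHAR(32) NOT NULL, THEID INTEGER NOT NULL)")
      .executeUpdate()
  }
} finally {
  conn.close()
}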
df.registerTempTable("jsonTable"); } @@ -75,10 +75,8 @@ public void setUp() throws IOException { public void saveAndLoad() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("json", SaveMode.ErrorIfExists, options); - + df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); - checkAnswer(loadedDF, df.collectAsList()); } @@ -86,12 +84,12 @@ public void saveAndLoad() { public void saveAndLoadWithSchema() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("json", SaveMode.ErrorIfExists, options); + df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); List fields = new ArrayList(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.load("json", schema, options); + DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2abfe7f167f77..5a7b6f0aac6f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -221,22 +221,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) + assert(TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size === 3) + assert( + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) + assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + .collect().length === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).getInt(0) === 1) assert(rows(0).getBoolean(1) === false) assert(rows(0).getInt(2) === 3) @@ -246,7 +249,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -286,24 +289,28 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + val rows = TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) 
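For reference, a hedged Scala sketch of the generic load/save migration that the Java suite above exercises; sqlContext and df are assumed to exist and the paths are placeholders. The first line of each pair still compiles but emits a deprecation warning as of 1.4.0.

import org.apache.spark.sql.SaveMode

// Reading with the old generic entry point vs. the DataFrameReader builder.
val loadedOld = sqlContext.load("/tmp/data.json", "json")
val loadedNew = sqlContext.read.format("json").load("/tmp/data.json")

// Writing with the old generic entry point vs. the DataFrameWriter builder.
df.save("/tmp/out.json", "json", SaveMode.ErrorIfExists)
df.write.format("json").mode(SaveMode.ErrorIfExists).save("/tmp/out.json")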
assert(rows(1).getAs[java.sql.Date](1) === null) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test DATE types in cache") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - TestSQLContext - .jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().registerTempTable("mycached_date") + val rows = + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test types for null value") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.NULLTYPES").collect() + val rows = TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 0800eded443de..2e4c12f9da80c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,7 +22,7 @@ import java.util.Properties import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ @@ -90,64 +90,66 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { test("Basic CREATE") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - df.createJDBCTable(url, "TEST.BASICCREATETEST", false) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length) + df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) + assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.DROPTEST", false, properties) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.DROPTEST", properties) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) - df2.createJDBCTable(url1, "TEST.DROPTEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { val df = 
TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url, "TEST.APPENDTEST", false) - df2.insertIntoJDBC(url, "TEST.APPENDTEST", false) - assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length) + df.write.jdbc(url, "TEST.APPENDTEST", new Properties) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) + assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.TRUNCATETEST", false, properties) - df2.insertIntoJDBC(url1, "TEST.TRUNCATETEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { - df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) } } - + test("INSERT to JDBC Datasource") { TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } - + test("INSERT to JDBC Datasource with overwrite") { TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 53ddecf57958b..58fe96adab17e 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -81,7 +81,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd 
= sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); df.registerTempTable("jsonTable"); } @@ -96,7 +96,11 @@ public void tearDown() throws IOException { public void saveExternalTableAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -115,7 +119,11 @@ public void saveExternalTableAndQueryIt() { public void saveExternalTableWithSchemaAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -138,7 +146,11 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { Map options = new HashMap(); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index fc6c3c35037b0..945596db80326 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -162,7 +162,7 @@ class CachedTableSuite extends QueryTest { test("REFRESH TABLE also needs to recache the data (data source tables)") { val tempPath: File = Utils.createTempDir() tempPath.delete() - table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") createExternalTable("refreshTable", tempPath.toString, "parquet") checkAnswer( @@ -172,7 +172,7 @@ class CachedTableSuite extends QueryTest { sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) // Append new data. - table("src").save(tempPath.toString, "parquet", SaveMode.Append) + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) // We are still using the old data. assertCached(table("refreshTable")) checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 58b0b80c31e2e..30db976a3ae74 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -409,11 +409,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { val originalDefaultSource = conf.defaultDataSourceName val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + val df = read.json(rdd) conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") // Save the df as a managed table (by not specifiying the path). 
- df.saveAsTable("savedJsonTable") + df.write.saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), @@ -443,11 +443,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { val originalDefaultSource = conf.defaultDataSourceName val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + val df = read.json(rdd) conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") // Save the df as a managed table (by not specifiying the path). - df.saveAsTable("savedJsonTable") + df.write.saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), @@ -455,17 +455,17 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // Right now, we cannot append to an existing JSON table. intercept[RuntimeException] { - df.saveAsTable("savedJsonTable", SaveMode.Append) + df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") } // We can overwrite it. - df.saveAsTable("savedJsonTable", SaveMode.Overwrite) + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), df.collect()) // When the save mode is Ignore, we will do nothing when the table already exists. - df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") assert(df.schema === table("savedJsonTable").schema) checkAnswer( sql("SELECT * FROM savedJsonTable"), @@ -479,11 +479,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { // Create an external table by specifying the path. conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), df.collect()) @@ -501,14 +501,13 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { val originalDefaultSource = conf.defaultDataSourceName val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + val df = read.json(rdd) conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) + df.write.format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") createExternalTable("createdJsonTable", tempPath.toString) @@ -566,7 +565,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + read.json(rdd).registerTempTable("jt") sql( """ |create table test_parquet_ctas STORED AS parquET @@ -601,7 +600,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructType( StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) assert(df1.schema === expectedSchema1) - df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite) + 
df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("arrayInParquet") val df2 = createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") @@ -610,10 +609,10 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) assert(df2.schema === expectedSchema2) df2.insertInto("arrayInParquet", overwrite = false) - createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", "parquet", SaveMode.Append) + createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a").write + .mode(SaveMode.Append).saveAsTable("arrayInParquet") refreshTable("arrayInParquet") checkAnswer( @@ -634,7 +633,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructType( StructField("a", mapType1, nullable = true) :: Nil) assert(df1.schema === expectedSchema1) - df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("mapInParquet") val df2 = createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") @@ -644,10 +643,10 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { StructField("a", mapType2, nullable = true) :: Nil) assert(df2.schema === expectedSchema2) df2.insertInto("mapInParquet", overwrite = false) - createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a") - .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("mapInParquet", "parquet", SaveMode.Append) + createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a").write.mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. 
+ createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a").write + .format("parquet").mode(SaveMode.Append).saveAsTable("mapInParquet") refreshTable("mapInParquet") checkAnswer( @@ -711,30 +710,30 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { def createDF(from: Int, to: Int): DataFrame = createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") - createDF(0, 9).saveAsTable("insertParquet", "parquet") + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).saveAsTable("insertParquet", "parquet") + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") } - createDF(10, 19).saveAsTable("insertParquet", "parquet", SaveMode.Append) + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).saveAsTable("insertParquet", "parquet", SaveMode.Append) + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("insertParquet") } - createDF(30, 39).saveAsTable("insertParquet", SaveMode.Append) + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) @@ -744,11 +743,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).saveAsTable("insertParquet", SaveMode.Overwrite) + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).saveAsTable("insertParquet", SaveMode.Ignore) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") checkAnswer( sql("SELECT p.c1, c2 FROM insertParquet p"), (50 to 59).map(i => Row(i, s"str$i"))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 8ad3627504229..3dfa6e72e1242 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) @@ -31,14 +31,14 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { - jsonRDD(sparkContext.makeRDD( + 
read.json(sparkContext.makeRDD( """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index dfe73c62c42b9..ca2c4b4019c55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -535,14 +535,14 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4296 Grouping field with Hive UDF as sub expression") { val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") @@ -550,7 +550,7 @@ class SQLQuerySuite extends QueryTest { test("resolve udtf with single alias") { val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } @@ -563,7 +563,7 @@ class SQLQuerySuite extends QueryTest { // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") val originalConf = getConf("spark.sql.hive.convertCTAS", "false") setConf("spark.sql.hive.convertCTAS", "false") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index a0075f1e44ca8..05d99983b6a63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -150,9 +150,9 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { } val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - jsonRDD(rdd1).registerTempTable("jt") + read.json(rdd1).registerTempTable("jt") val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - jsonRDD(rdd2).registerTempTable("jt_array") + read.json(rdd2).registerTempTable("jt_array") setConf("spark.sql.hive.convertMetastoreParquet", "true") } @@ -617,16 +617,16 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. 
- val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) - df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) - df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files // and then merge metadata in footers of these four (two outdated ones and two latest one), @@ -663,7 +663,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.saveAsTable("alwaysNullable", "parquet") + df.write.format("parquet").saveAsTable("alwaysNullable") val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) val arrayType2 = ArrayType(IntegerType, containsNull = true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index f44b3c521e647..9d9b436cabe3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -120,10 +120,7 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - non-partitioned table - ErrorIfExists") { withTempDir { file => intercept[RuntimeException] { - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.ErrorIfExists) + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) } } } @@ -233,10 +230,8 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - partitioned table - Ignore") { withTempDir { file => - partitionedTestDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) + partitionedTestDF.write + .format(dataSourceName).mode(SaveMode.Ignore).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) val fs = path.getFileSystem(SparkHadoopUtil.get.conf) @@ -249,11 +244,9 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - non-partitioned table - Overwrite") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") withTable("t") { checkAnswer(table("t"), testDF.collect()) @@ -261,15 +254,8 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } test("saveAsTable()/load() - non-partitioned table - Append") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.saveAsTable( - tableName = 
"t", - source = dataSourceName, - mode = SaveMode.Append) + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite).saveAsTable("t") + testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -281,10 +267,7 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { withTempTable("t") { intercept[AnalysisException] { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists) + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") } } } @@ -293,21 +276,16 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { Seq.empty[(Int, String)].toDF().registerTempTable("t") withTempTable("t") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore) - + testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(table("t").collect().isEmpty) } } test("saveAsTable()/load() - partitioned table - simple queries") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) + partitionedTestDF.write.format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") withTable("t") { checkQueries(table("t")) @@ -492,11 +470,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) } } } @@ -518,18 +494,16 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") - .saveAsParquetFile(partitionDir.toString) + .write.parquet(partitionDir.toString) } val dataSchemaWithPartition = StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) } } } From ba4f8ca0d9ccc0a39a8a0105541d0cc1f4912d62 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 23:20:09 -0700 Subject: [PATCH 083/109] [MINOR] [SQL] Removes an unreachable case clause This case clause is already covered by the one above, and generates a compilation warning. 
Author: Cheng Lian Closes #6214 from liancheng/remove-unreachable-code and squashes the following commits: c38ca7c [Cheng Lian] Removes an unreachable case clause --- sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 1eacdde7413f1..ab33125b74c17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -101,7 +101,6 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } } - case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) => // OK case logical.InsertIntoTable(LogicalRelation(_: HadoopFsRelation), _, _, _, _) => // OK case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. From 1a7b9ce80bb5649796dda48d6a6d662a2809d0ef Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 17 May 2015 00:12:20 -0700 Subject: [PATCH 084/109] [MINOR] Add 1.3, 1.3.1 to master branch EC2 scripts cc pwendell P.S: I can't believe this was outdated all along ? Author: Shivaram Venkataraman Closes #6215 from shivaram/update-ec2-map and squashes the following commits: ae3937a [Shivaram Venkataraman] Add 1.3, 1.3.1 to master branch EC2 scripts --- ec2/spark_ec2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ab4a96f232c13..be92d5f45aa77 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -48,7 +48,7 @@ from urllib.request import urlopen, Request from urllib.error import HTTPError -SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_VERSION = "1.3.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -65,6 +65,8 @@ "1.1.1", "1.2.0", "1.2.1", + "1.3.0", + "1.3.1", ]) SPARK_TACHYON_MAP = { @@ -75,6 +77,8 @@ "1.1.1": "0.5.0", "1.2.0": "0.5.0", "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION From edf09ea1bd4bf7692e0085ad9c70cb1bfc8d06d8 Mon Sep 17 00:00:00 2001 From: scwf Date: Sun, 17 May 2015 15:17:11 +0800 Subject: [PATCH 085/109] [SQL] [MINOR] Skip unresolved expression for InConversion Author: scwf Closes #6145 from scwf/InConversion and squashes the following commits: 5c8ac6b [scwf] minir fix for InConversion --- .../apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index fe0d3f29977c3..b45b17d856fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -296,6 +296,9 @@ trait HiveTypeCoercion { */ object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. 
+ case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } From 339905578790fa37fcad9684b859b443313a5aa2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 17 May 2015 15:42:21 +0800 Subject: [PATCH 086/109] [SPARK-7447] [SQL] Don't re-merge Parquet schema when the relation is deserialized JIRA: https://issues.apache.org/jira/browse/SPARK-7447 `MetadataCache` in `ParquetRelation2` is annotated as `transient`. When `ParquetRelation2` is deserialized, we ask `MetadataCache` to refresh and perform schema merging again. It is time-consuming especially for very many parquet files. With the new `FSBasedParquetRelation`, although `MetadataCache` is not `transient` now, `MetadataCache.refresh()` still performs schema merging again when the relation is deserialized. Author: Liang-Chi Hsieh Closes #6012 from viirya/without_remerge_schema and squashes the following commits: 2663957 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into without_remerge_schema 6ac7d93 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into without_remerge_schema b0fc09b [Liang-Chi Hsieh] Don't generate and merge parquetSchema multiple times. --- .../apache/spark/sql/parquet/newParquet.scala | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 946062f6ea64e..bcbdb1ebd236a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -340,7 +340,7 @@ private[sql] class ParquetRelation2( // Schema of the actual Parquet files, without partition columns discovered from partition // directory paths. - var dataSchema: StructType = _ + var dataSchema: StructType = null // Schema of the whole table, including partition columns. var schema: StructType = _ @@ -379,19 +379,23 @@ private[sql] class ParquetRelation2( f -> new Footer(f.getPath, parquetMetadata) }.seq.toMap - dataSchema = { - val dataSchema0 = - maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // case insensitivity issue and possible schema mismatch (probably caused by schema - // evolution). - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) - .getOrElse(dataSchema0) + // If we already get the schema, don't need to re-compute it since the schema merging is + // time-consuming. + if (dataSchema == null) { + dataSchema = { + val dataSchema0 = + maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // case insensitivity issue and possible schema mismatch (probably caused by schema + // evolution). 
+ maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) + } } } From 50217667cc1239ed3b15f4d10907b727ed85d7fa Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Sun, 17 May 2015 17:03:11 +0100 Subject: [PATCH 087/109] =?UTF-8?q?[SPARK-7669]=20Builds=20against=20Hadoo?= =?UTF-8?q?p=202.6+=20get=20inconsistent=20curator=20depend=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a new profile, `hadoop-2.6`, copying over the hadoop-2.4 properties, updating ZK to 3.4.6 and making the curator version a configurable option. That keeps the curator-recipes JAR in sync with that used in hadoop. There's one more option to consider: making the full curator-client version explicit with its own dependency version. This will pin down the version from hadoop and hive imports Author: Steve Loughran Closes #6191 from steveloughran/stevel/SPARK-7669-hadoop-2.6 and squashes the following commits: e3e281a [Steve Loughran] SPARK-7669 declare the version of curator-client and curator-framework JARs 2901ea9 [Steve Loughran] SPARK-7669 Builds against Hadoop 2.6+ get inconsistent curator dependencies --- pom.xml | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 1b45cdb67012a..6768a039d11e0 100644 --- a/pom.xml +++ b/pom.xml @@ -130,6 +130,7 @@ hbase 1.4.0 3.4.5 + 2.4.0 org.spark-project.hive 0.13.1a @@ -707,7 +708,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} ${hadoop.deps.scope} @@ -716,6 +717,16 @@ + + org.apache.curator + curator-client + ${curator.version} + + + org.apache.curator + curator-framework + ${curator.version} + org.apache.hadoop hadoop-client @@ -1679,6 +1690,17 @@ + + hadoop-2.6 + + 2.6.0 + 0.9.3 + 3.1.1 + 3.4.6 + 2.6.0 + + + yarn @@ -1709,7 +1731,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} org.apache.zookeeper From f2cc6b5bccc3a70fd7d69183b1a068800831fe19 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 17 May 2015 09:30:49 -0700 Subject: [PATCH 088/109] [SPARK-7660] Wrap SnappyOutputStream to work around snappy-java bug This patch wraps `SnappyOutputStream` to ensure that `close()` is idempotent and to guard against write-after-`close()` bugs. This is a workaround for https://github.com/xerial/snappy-java/issues/107, a bug where a non-idempotent `close()` method can lead to stream corruption. We can remove this workaround if we upgrade to a snappy-java version that contains my fix for this bug, but in the meantime this patch offers a backportable Spark fix. 
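As background on the failure mode, here is a hedged sketch (neither the Spark nor the snappy-java code) of why a non-idempotent close() corrupts output: stream decorators are routinely closed more than once, for example by an outer try/finally plus an enclosing wrapper, and a decorator that re-emits data on every close() silently duplicates it.

    import java.io.{ByteArrayOutputStream, FilterOutputStream, OutputStream}

    // Toy decorator with a non-idempotent close(): it appends a footer every time
    // close() is called. A stand-in for the snappy-java behaviour described above.
    class FooterOutputStream(underlying: OutputStream) extends FilterOutputStream(underlying) {
      override def close(): Unit = {
        underlying.write("FOOTER".getBytes("UTF-8"))  // re-emitted on every close() call
        super.close()
      }
    }

    object DoubleCloseDemo {
      def main(args: Array[String]): Unit = {
        val sink = new ByteArrayOutputStream()
        val out  = new FooterOutputStream(sink)
        out.write("data".getBytes("UTF-8"))
        out.close()
        out.close() // second close(), e.g. from an outer try/finally
        // Prints "dataFOOTERFOOTER": the second close() appended a duplicate footer.
        println(sink.toString("UTF-8"))
      }
    }

A wrapper that tracks a closed flag, like the SnappyOutputStreamWrapper added below, makes the second close() a no-op and turns writes after close() into an explicit IOException instead of silent corruption.
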
Author: Josh Rosen Closes #6176 from JoshRosen/SPARK-7660-wrap-snappy and squashes the following commits: 8b77aae [Josh Rosen] Wrap SnappyOutputStream to fix SPARK-7660 --- .../apache/spark/io/CompressionCodec.scala | 49 ++++++++++++++++++- .../unsafe/UnsafeShuffleWriterSuite.java | 8 --- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0756cdb2ed8e6..0d8ac1f80a9f4 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,7 +17,7 @@ package org.apache.spark.io -import java.io.{InputStream, OutputStream} +import java.io.{IOException, InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} @@ -154,8 +154,53 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt - new SnappyOutputStream(s, blockSize) + new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) } override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } + +/** + * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version + * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. + */ +private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { + + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte]): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b, off, len) + } + + override def flush(): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.flush() + } + + override def close(): Unit = { + if (!closed) { + closed = true + os.close() + } + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 78e52643531e0..730d265c87f88 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -35,7 +35,6 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.xerial.snappy.buffer.CachedBufferAllocator; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; @@ -97,13 +96,6 @@ public OutputStream apply(OutputStream stream) { @After public void tearDown() { Utils.deleteRecursively(tempDir); - // This call is a workaround for SPARK-7660, a snappy-java bug which is exposed by this test - // suite. 
Clearing the cached buffer allocator's pool of reusable buffers masks this bug, - // preventing a test failure in JavaAPISuite that would otherwise occur. The underlying bug - // needs to be fixed, but in the meantime this workaround avoids spurious Jenkins failures. - synchronized (CachedBufferAllocator.class) { - CachedBufferAllocator.queueTable.clear(); - } final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { fail("Test leaked " + leakedMemory + " bytes of managed memory"); From 564562874f589c4c8bcabcd9d6eb9a6b0eada938 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 17 May 2015 11:59:28 -0700 Subject: [PATCH 089/109] [SPARK-7686] [SQL] DescribeCommand is assigned wrong output attributes in SparkStrategies In `SparkStrategies`, `RunnableDescribeCommand` is called with the output attributes of the table being described rather than the attributes for the `describe` command's output. I discovered this issue because it caused type conversion errors in some UnsafeRow conversion code that I'm writing. Author: Josh Rosen Closes #6217 from JoshRosen/SPARK-7686 and squashes the following commits: 953a344 [Josh Rosen] Fix SPARK-7686 with a simple change in SparkStrategies. a4eec9f [Josh Rosen] Add failing regression test for SPARK-7686 --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 4 ++-- .../scala/org/apache/spark/sql/sources/DDLTestSuite.scala | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index af0029cb84f9a..3f6a0345bc17d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -354,10 +354,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. 
Use a HiveContext instead.") - case LogicalDescribeCommand(table, isExtended) => + case describe @ LogicalDescribeCommand(table, isExtended) => val resultPlan = self.sqlContext.executePlan(table).executedPlan ExecutedCommand( - RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 6664e8d64c13a..f5106f67a08df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -99,4 +99,10 @@ class DDLTestSuite extends DataSourceTest { Row("arrayType", "array", ""), Row("structType", "struct", "") )) + + test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { + val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output + assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) + assert(attributes.map(_.dataType).toSet === Set(StringType)) + } } From 2ca60ace8f42cf0bd4569d86c86c37a8a2b6a37c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 May 2015 12:43:15 -0700 Subject: [PATCH 090/109] [SPARK-7491] [SQL] Allow configuration of classloader isolation for hive Author: Michael Armbrust Closes #6167 from marmbrus/configureIsolation and squashes the following commits: 6147cbe [Michael Armbrust] filter other conf 22cc3bc7 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into configureIsolation 07476ee [Michael Armbrust] filter empty prefixes dfdf19c [Michael Armbrust] [SPARK-6906][SQL] Allow configuration of classloader isolation for hive --- .../apache/spark/sql/hive/HiveContext.scala | 33 +++++++++++++++++-- .../hive/client/IsolatedClientLoader.scala | 14 ++++---- .../apache/spark/sql/hive/test/TestHive.scala | 9 ++++- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d98c36e947a1..2733ebdb95bca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -122,6 +122,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS, "builtin") + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = + getConf("spark.sql.hive.metastore.sharedPrefixes", jdbcPrefixes) + .split(",").filterNot(_ == "") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc").mkString(",") + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. 
org.apache.spark.*) + */ + protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = + getConf("spark.sql.hive.metastore.barrierPrefixes", "") + .split(",").filterNot(_ == "") + @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -179,12 +202,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else if (hiveMetastoreJars == "maven") { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig ) + IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) } else { // Convert to files and expand any directories. val jars = @@ -210,7 +235,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } isolatedLoader.client } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7f94c93ba49c1..196a3d836cab2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -56,8 +56,7 @@ private[hive] object IsolatedClientLoader { (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil)) .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+ "com.google.guava:guava:14.0.1" :+ - "org.apache.hadoop:hadoop-client:2.4.0" :+ - "mysql:mysql-connector-java:5.1.12" + "org.apache.hadoop:hadoop-client:2.4.0" val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -106,7 +105,9 @@ private[hive] class IsolatedClientLoader( val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, - val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader) + val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, + val sharedPrefixes: Seq[String] = Seq.empty, + val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { // Check to make sure that the root classloader does not know about Hive. @@ -122,13 +123,14 @@ private[hive] class IsolatedClientLoader( name.startsWith("scala.") || name.startsWith("com.google") || name.startsWith("java.lang.") || - name.startsWith("java.net") + name.startsWith("java.net") || + sharedPrefixes.exists(name.startsWith) /** True if `name` refers to a spark class that must see specific version of Hive. 
*/ protected def isBarrierClass(name: String): Boolean = - name.startsWith("org.apache.spark.sql.hive.execution.PairSerDe") || name.startsWith(classOf[ClientWrapper].getName) || - name.startsWith(classOf[ReflectionMagic].getName) + name.startsWith(classOf[ReflectionMagic].getName) || + barrierPrefixes.exists(name.startsWith) protected def classToPath(name: String): String = name.replaceAll("\\.", "/") + ".class" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 1598d4bd47550..964828407481e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -48,7 +48,14 @@ import scala.collection.JavaConversions._ // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) + new SparkContext( + "local[2]", + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set( + "spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe"))) /** * A locally running test instance of Spark's Hive execution engine. From ca4257aec658aaa87f4f097dd7534033d5f13ddc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 17 May 2015 16:49:07 -0700 Subject: [PATCH 091/109] [SPARK-6514] [SPARK-5960] [SPARK-6656] [SPARK-7679] [STREAMING] [KINESIS] Updates to the Kinesis API SPARK-6514 - Use correct region SPARK-5960 - Allow AWS Credentials to be directly passed SPARK-6656 - Specify kinesis application name explicitly SPARK-7679 - Upgrade to latest KCL and AWS SDK. Author: Tathagata Das Closes #6147 from tdas/kinesis-api-update and squashes the following commits: f23ea77 [Tathagata Das] Updated versions and updated APIs 373b201 [Tathagata Das] Updated Kinesis API --- .../kinesis/KinesisCheckpointState.scala | 2 +- .../streaming/kinesis/KinesisReceiver.scala | 152 +++++----- .../kinesis/KinesisRecordProcessor.scala | 32 ++- .../streaming/kinesis/KinesisUtils.scala | 263 +++++++++++++++--- .../kinesis/KinesisReceiverSuite.scala | 15 +- pom.xml | 4 +- 6 files changed, 348 insertions(+), 120 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 588e86a1887ec..1c9b0c218ae18 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -48,7 +48,7 @@ private[kinesis] class KinesisCheckpointState( /** * Advance the checkpoint clock by the checkpoint interval. 
*/ - def advanceCheckpoint() = { + def advanceCheckpoint(): Unit = { checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a7fe4476cacb8..01608fbd3fd31 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,32 +16,31 @@ */ package org.apache.spark.streaming.kinesis -import java.net.InetAddress import java.util.UUID +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} + import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import com.amazonaws.auth.AWSCredentialsProvider -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +private[kinesis] +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends BasicAWSCredentials(accessKeyId, secretKey) with Serializable /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) - * as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers - * to run within a Spark Executor. + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, @@ -49,6 +48,8 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Region name used by the Kinesis Client Library for + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. 
@@ -59,92 +60,103 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). * @param storageLevel Storage level to use for storing the received objects - * - * @return ReceiverInputDStream[Array[Byte]] + * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies + * the credentials */ private[kinesis] class KinesisReceiver( appName: String, streamName: String, endpointUrl: String, - checkpointInterval: Duration, + regionName: String, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel) - extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => - - /* - * The following vars are built in the onStart() method which executes in the Spark Worker after - * this code is serialized and shipped remotely. - */ - - /* - * workerId should be based on the ip address of the actual Spark Worker where this code runs - * (not the Driver's ip address.) - */ - var workerId: String = null + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => /* - * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file at the default location (~/.aws/credentials) shared by all - * AWS SDKs and the AWS CLI - * Instance profile credentials delivered through the Amazon EC2 metadata service + * ================================================================================= + * The following vars are initialize in the onStart() method which executes in the + * Spark worker after this Receiver is serialized and shipped to the worker. + * ================================================================================= */ - var credentialsProvider: AWSCredentialsProvider = null - - /* KCL config instance. */ - var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. + /** + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker where this code runs + * (not the driver's IP address.) */ - var recordProcessorFactory: IRecordProcessorFactory = null + private var workerId: String = null - /* - * Create a Kinesis Worker. - * This is the core client abstraction from the Kinesis Client Library (KCL). - * We pass the RecordProcessorFactory from above as well as the KCL config instance. - * A Kinesis Worker can process 1..* shards from the given stream - each with its - * own RecordProcessor. + /** + * Worker is the core client abstraction from the Kinesis Client Library (KCL). + * A worker can process more than one shards from the given stream. + * Each shard is assigned its own IRecordProcessor and the worker run multiple such + * processors. */ - var worker: Worker = null + private var worker: Worker = null /** - * This is called when the KinesisReceiver starts and must be non-blocking. - * The KCL creates and manages the receiving/processing thread pool through the Worker.run() - * method. 
+ * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { workerId = Utils.localHostName() + ":" + UUID.randomUUID() - credentialsProvider = new DefaultAWSCredentialsProviderChain() - kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, - credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) - recordProcessorFactory = new IRecordProcessorFactory { + + // KCL config instance + val awsCredProvider = resolveAWSCredentialsProvider() + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + val recordProcessorFactory = new IRecordProcessorFactory { override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, workerId, new KinesisCheckpointState(checkpointInterval)) } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) worker.run() + logInfo(s"Started receiver with workerId $workerId") } /** - * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. - * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - worker.shutdown() - logInfo(s"Shut down receiver with workerId $workerId") + if (worker != null) { + worker.shutdown() + logInfo(s"Stopped receiver for workerId $workerId") + worker = null + } workerId = null - credentialsProvider = null - kinesisClientLibConfiguration = null - recordProcessorFactory = null - worker = null + } + + /** + * If AWS credential is provided, return a AWSCredentialProvider returning that credential. + * Otherwise, return the DefaultAWSCredentialsProviderChain. 
+ */ + private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { + awsCredentialsOption match { + case Some(awsCredentials) => + logInfo("Using provided AWS credentials") + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + logInfo("Using DefaultAWSCredentialsProviderChain") + new DefaultAWSCredentialsProviderChain() + } } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index af8cd875b4541..f65e743c4e2a3 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,7 +35,10 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes @@ -47,8 +50,8 @@ private[kinesis] class KinesisRecordProcessor( workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - /* shardId to be populated during initialize() */ - var shardId: String = _ + // shardId to be populated during initialize() + private var shardId: String = _ /** * The Kinesis Client Library calls this method during IRecordProcessor initialization. @@ -56,8 +59,8 @@ private[kinesis] class KinesisRecordProcessor( * @param shardId assigned by the KCL to this particular RecordProcessor. */ override def initialize(shardId: String) { - logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") } /** @@ -73,12 +76,17 @@ private[kinesis] class KinesisRecordProcessor( if (!receiver.isStopped()) { try { /* - * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - */ + * Notes: + * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + * 3) For performance, the BlockGenerator is asynchronously queuing elements within its + * memory before creating blocks. This prevents the small block scenario, but requires + * that you register callbacks to know when a block has been generated and stored + * (WAL is sufficient for storage) before can checkpoint back to the source. 
+ */ batch.foreach(record => receiver.store(record.getData().array())) logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") @@ -116,7 +124,7 @@ private[kinesis] class KinesisRecordProcessor( logError(s"Exception: WorkerId $workerId encountered and exception while storing " + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e } } @@ -190,7 +198,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) } - /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + /* Throw: Shutdown has been requested by the Kinesis Client Library. */ case _: ShutdownException => { logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 96f4399accd3a..b114bcff92d0f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,29 +16,75 @@ */ package org.apache.spark.streaming.kinesis -import org.apache.spark.annotation.Experimental +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.apache.spark.streaming.{Duration, StreamingContext} -/** - * Helper class to create Amazon Kinesis Input Stream - * :: Experimental :: - */ -@Experimental object KinesisUtils { /** - * Create an InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: - * @param ssc StreamingContext object + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
+ * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, None)) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -48,28 +94,84 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return ReceiverInputDStream[Array[Byte]] + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param ssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( ssc: StreamingContext, streamName: String, endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, - checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, checkpointInterval, storageLevel, None)) } /** - * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
+ * * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -79,19 +181,116 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return JavaReceiverInputDStream[Array[Byte]] + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
*/ - @Experimental def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param jssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ */ + @deprecated("use other forms of createStream", "1.4.0") + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { - jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, - endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream( + jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) + } + + private def getRegionByEndpoint(endpointUrl: String): String = { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } + + private def validateRegion(regionName: String): String = { + Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { + throw new IllegalArgumentException(s"Region name '$regionName' is not valid") + } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 255fe65819608..7c17ee9dceddd 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -40,6 +40,7 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -81,12 +82,20 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("kinesis utils api") { + test("KinesisUtils API") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving - val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + ssc.stop() } diff --git a/pom.xml b/pom.xml index 6768a039d11e0..6f525b6ac81a3 100644 --- a/pom.xml +++ b/pom.xml @@ -148,8 +148,8 @@ 1.7.7 hadoop2 0.7.1 - 1.8.3 - 1.1.0 + 1.9.16 + 1.2.1 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt From 2f22424e9f6624097b292cb70e00787b69d80718 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 17 May 2015 16:51:57 -0700 Subject: [PATCH 092/109] [SQL] [MINOR] use catalyst type converter in ScalaUdf It's a follow-up of 
https://github.com/apache/spark/pull/5154, we can speed up scala udf evaluation by create type converter in advance. Author: Wenchen Fan Closes #6182 from cloud-fan/tmp and squashes the following commits: 241cfe9 [Wenchen Fan] use converter in ScalaUdf --- .../org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 9a77ca624ebe2..d22eb10ad399f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -956,7 +956,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - - override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) + val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: Row): Any = converter(f(input)) } From ff71d34e00b64d70f671f9bf3e63aec39cd525e5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 17 May 2015 20:37:19 -0700 Subject: [PATCH 093/109] [SPARK-7693][Core] Remove "import scala.concurrent.ExecutionContext.Implicits.global" Learnt a lesson from SPARK-7655: Spark should avoid to use `scala.concurrent.ExecutionContext.Implicits.global` because the user may submit blocking actions to `scala.concurrent.ExecutionContext.Implicits.global` and exhaust all threads in it. This could crash Spark. So Spark should always use its own thread pools for safety. This PR removes all usages of `scala.concurrent.ExecutionContext.Implicits.global` and uses proper thread pools to replace them. 
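For illustration, here is a minimal standalone sketch of the pattern this change adopts: blocking work is submitted to a dedicated, bounded thread pool that is passed explicitly as the ExecutionContext, instead of relying on `scala.concurrent.ExecutionContext.Implicits.global`. The object name and pool size below are illustrative only, not taken from the patch.

~~~
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object DedicatedPoolSketch {
  // A dedicated, bounded pool for potentially blocking work. Exhausting these
  // threads cannot starve unrelated code that uses other pools.
  private val blockingPool = ExecutionContext.fromExecutorService(
    Executors.newFixedThreadPool(8))

  def main(args: Array[String]): Unit = {
    val result = Future {
      Thread.sleep(100) // stand-in for a blocking call
      42
    }(blockingPool) // the ExecutionContext is passed explicitly

    println(Await.result(result, 10.seconds))
    blockingPool.shutdown() // release the threads when done
  }
}
~~~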
Author: zsxwing Closes #6223 from zsxwing/SPARK-7693 and squashes the following commits: a33ff06 [zsxwing] Decrease the max thread number from 1024 to 128 cf4b3fc [zsxwing] Remove "import scala.concurrent.ExecutionContext.Implicits.global" --- .../CoarseGrainedExecutorBackend.scala | 9 +++--- .../apache/spark/rdd/AsyncRDDActions.scala | 13 +++++++-- .../apache/spark/storage/BlockManager.scala | 17 ++++++++--- .../spark/storage/BlockManagerMaster.scala | 29 ++++++++++++------- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../receiver/ReceiverSupervisor.scala | 14 ++++++--- 6 files changed, 58 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ed159dec4f998..f3a26f54a81fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -55,18 +55,19 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { - import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) - } onComplete { + }(ThreadUtils.sameThread).onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) - } + }(ThreadUtils.sameThread) } def extractLogUrls: Map[String, String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ec185340c3a2d..bbf1b83af0795 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,8 +19,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.util.ThreadUtils + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.ExecutionContext import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -66,6 +68,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val f = new ComplexFutureAction[Seq[T]] f.run { + // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which + // is a cached thread pool. 
val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -101,7 +105,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi partsScanned += numPartsToTry } results.toSeq - } + }(AsyncRDDActions.futureExecutionContext) f } @@ -123,3 +127,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi (index, data) => Unit, Unit) } } + +private object AsyncRDDActions { + val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)) +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index cc794e5c90ffa..16d67cbfca80b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,8 +21,7 @@ import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ import scala.util.Random @@ -77,6 +76,9 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) + // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) @@ -266,11 +268,13 @@ private[spark] class BlockManager( asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool reregister() asyncReregisterLock.synchronized { asyncReregisterTask = null } - } + }(futureExecutionContext) } } } @@ -744,7 +748,11 @@ private[spark] class BlockManager( case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, putLevel) } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) case _ => null } @@ -1218,6 +1226,7 @@ private[spark] class BlockManager( } metadataCleaner.cancel() broadcastCleaner.cancel() + futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index a85e1c7632973..abcad9438bf28 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,13 +17,14 @@ package org.apache.spark.storage +import scala.collection.Iterable +import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import 
org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.RpcUtils +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( @@ -102,8 +103,8 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") - } + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { Await.result(future, timeout) } @@ -114,8 +115,8 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") - } + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { Await.result(future, timeout) } @@ -128,8 +129,8 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") - } + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { Await.result(future, timeout) } @@ -169,11 +170,17 @@ class BlockManagerMaster( val response = driverEndpoint. askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip - val result = Await.result(Future.sequence(futures), timeout) - if (result == null) { + implicit val sameThread = ThreadUtils.sameThread + val cbf = + implicitly[ + CanBuildFrom[Iterable[Future[Option[BlockStatus]]], + Option[BlockStatus], + Iterable[Option[BlockStatus]]]] + val blockStatus = Await.result( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout) + if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } - val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => status.map { s => (blockManagerId, s) } }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index fe43fc4125c8e..b8b12be8756f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -78,5 +78,5 @@ case class BroadcastHashJoin( object BroadcastHashJoin { private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 1024)) + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 4943f29395d12..33be067ebdaf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -18,14 +18,14 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer +import 
scala.concurrent._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import java.util.concurrent.CountDownLatch -import scala.concurrent._ -import ExecutionContext.Implicits.global +import org.apache.spark.util.ThreadUtils /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -46,6 +46,9 @@ private[streaming] abstract class ReceiverSupervisor( // Attach the executor to the receiver receiver.attachExecutor(this) + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) + /** Receiver id */ protected val streamId = receiver.streamId @@ -111,6 +114,7 @@ private[streaming] abstract class ReceiverSupervisor( stoppingError = error.orNull stopReceiver(message, error) onStop(message, error) + futureExecutionContext.shutdownNow() stopLatch.countDown() } @@ -150,6 +154,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Restart receiver with delay */ def restartReceiver(message: String, error: Option[Throwable], delay: Int) { Future { + // This is a blocking action so we should use "futureExecutionContext" which is a cached + // thread pool. logWarning("Restarting receiver with delay " + delay + " ms: " + message, error.getOrElse(null)) stopReceiver("Restarting receiver with delay " + delay + "ms: " + message, error) @@ -158,7 +164,7 @@ private[streaming] abstract class ReceiverSupervisor( logInfo("Starting receiver again") startReceiver() logInfo("Receiver started again") - } + }(futureExecutionContext) } /** Check if receiver has been marked for stopping */ From 775e6f9909d4495cbc11c377508b43482d782742 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Sun, 17 May 2015 21:16:52 -0700 Subject: [PATCH 094/109] [SPARK-7694] [MLLIB] Use getOrElse for getting the threshold of LR model The `toString` method of `LogisticRegressionModel` calls `get` method on an Option (threshold) without a safeguard. In spark-shell, the following code `val model = algorithm.run(data).clearThreshold()` in lbfgs code will fail as `toString `method will be called right after `clearThreshold()` to show the results in the REPL. 
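As an aside, the failure mode and the fix can be reproduced with a plain `Option`; the sketch below is illustrative and does not use the MLlib classes.

~~~
object ThresholdToStringSketch {
  def main(args: Array[String]): Unit = {
    var threshold: Option[Double] = Some(0.5)

    // Unsafe: Option.get throws NoSuchElementException once the value is cleared.
    def unsafeDescription: String = s"threshold = ${threshold.get}"

    // Safe: getOrElse falls back to a placeholder, mirroring the patch.
    def safeDescription: String = s"threshold = ${threshold.getOrElse("None")}"

    println(unsafeDescription) // threshold = 0.5
    threshold = None           // analogous to clearThreshold()
    println(safeDescription)   // threshold = None
    // Calling unsafeDescription here would throw NoSuchElementException.
  }
}
~~~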
Author: Shuo Xiang Closes #6224 from coderxiang/getorelse and squashes the following commits: d5f53c9 [Shuo Xiang] use getOrElse for getting the threshold of LR model 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- .../apache/spark/mllib/classification/LogisticRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index bd2e9079ce1ae..2df4d21e8cd55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -163,7 +163,7 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" override def toString: String = { - s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}" + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } From e32c0f69f38ad729e25c2d5f90eb73b4453f8279 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 May 2015 01:10:55 -0700 Subject: [PATCH 095/109] [SPARK-7299][SQL] Set precision and scale for Decimal according to JDBC metadata instead of returned BigDecimal JIRA: https://issues.apache.org/jira/browse/SPARK-7299 When connecting with oracle db through jdbc, the precision and scale of `BigDecimal` object returned by `ResultSet.getBigDecimal` is not correctly matched to the table schema reported by `ResultSetMetaData.getPrecision` and `ResultSetMetaData.getScale`. So in case you insert a value like `19999` into a column with `NUMBER(12, 2)` type, you get through a `BigDecimal` object with scale as 0. But the dataframe schema has correct type as `DecimalType(12, 2)`. Thus, after you save the dataframe into parquet file and then retrieve it, you will get wrong result `199.99`. Because it is reported to be problematic on jdbc connection with oracle db. It might be difficult to add test case for it. But according to the user's test on JIRA, it solves this problem. Author: Liang-Chi Hsieh Closes #5833 from viirya/jdbc_decimal_precision and squashes the following commits: 69bc2b5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into jdbc_decimal_precision 928f864 [Liang-Chi Hsieh] Add comments. 5f9da94 [Liang-Chi Hsieh] Set up Decimal's precision and scale according to table schema instead of returned BigDecimal. 
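The underlying scale mismatch can be shown with plain `java.math.BigDecimal`; the sketch below is illustrative and independent of Spark or any particular JDBC driver.

~~~
import java.math.{BigDecimal => JBigDecimal}

object DecimalScaleSketch {
  def main(args: Array[String]): Unit = {
    // A driver may return the value of a NUMBER(12, 2) column with scale 0.
    val fromDriver = new JBigDecimal("19999")
    println(s"driver value = $fromDriver, scale = ${fromDriver.scale}") // scale = 0

    // Reinterpreting the unscaled value under the schema's scale (2) silently
    // turns 19999 into 199.99, the wrong result described above.
    val misread = JBigDecimal.valueOf(fromDriver.unscaledValue.longValue, 2)
    println(s"misread value  = $misread")

    // Applying the scale reported by the table metadata keeps the value intact.
    val rescaled = fromDriver.setScale(2)
    println(s"rescaled value = $rescaled") // 19999.00
  }
}
~~~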
--- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 95935ba874a72..4189dfcf956c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -300,7 +300,7 @@ private[sql] class JDBCRDD( abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case object DecimalConversion extends JDBCConversion + case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -317,8 +317,8 @@ private[sql] class JDBCRDD( schema.fields.map(sf => sf.dataType match { case BooleanType => BooleanConversion case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion - case DecimalType.Fixed(d) => DecimalConversion + case DecimalType.Unlimited => DecimalConversion(None) + case DecimalType.Fixed(d) => DecimalConversion(Some(d)) case DoubleType => DoubleConversion case FloatType => FloatConversion case IntegerType => IntegerConversion @@ -375,7 +375,22 @@ private[sql] class JDBCRDD( } else { mutableRow.update(i, null) } - case DecimalConversion => + // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal + // object returned by ResultSet.getBigDecimal is not correctly matched to the table + // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. + // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through + // a BigDecimal object with scale as 0. But the dataframe schema has correct type as + // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then + // retrieve it, you will get wrong result 199.99. + // So it is needed to set precision and scale for Decimal based on JDBC metadata. + case DecimalConversion(Some((p, s))) => + val decimalVal = rs.getBigDecimal(pos) + if (decimalVal == null) { + mutableRow.update(i, null) + } else { + mutableRow.update(i, Decimal(decimalVal, p, s)) + } + case DecimalConversion(None) => val decimalVal = rs.getBigDecimal(pos) if (decimalVal == null) { mutableRow.update(i, null) From 1ecfac6e387b0934bfb5a9bbb4ad74b81ec210a4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 18 May 2015 08:35:14 -0700 Subject: [PATCH 096/109] [SPARK-6657] [PYSPARK] Fix doc warnings Fixed the following warnings in `make clean html` under `python/docs`: ~~~ /Users/meng/src/spark/python/pyspark/mllib/evaluation.py:docstring of pyspark.mllib.evaluation.RankingMetrics.ndcgAt:3: ERROR: Unexpected indentation. /Users/meng/src/spark/python/pyspark/mllib/evaluation.py:docstring of pyspark.mllib.evaluation.RankingMetrics.ndcgAt:4: WARNING: Block quote ends without a blank line; unexpected unindent. /Users/meng/src/spark/python/pyspark/mllib/fpm.py:docstring of pyspark.mllib.fpm.FPGrowth.train:3: ERROR: Unexpected indentation. /Users/meng/src/spark/python/pyspark/mllib/fpm.py:docstring of pyspark.mllib.fpm.FPGrowth.train:4: WARNING: Block quote ends without a blank line; unexpected unindent. 
/Users/meng/src/spark/python/pyspark/sql/__init__.py:docstring of pyspark.sql.DataFrame.replace:16: WARNING: Field list ends without a blank line; unexpected unindent. /Users/meng/src/spark/python/pyspark/streaming/kafka.py:docstring of pyspark.streaming.kafka.KafkaUtils.createRDD:8: ERROR: Unexpected indentation. /Users/meng/src/spark/python/pyspark/streaming/kafka.py:docstring of pyspark.streaming.kafka.KafkaUtils.createRDD:9: WARNING: Block quote ends without a blank line; unexpected unindent. ~~~ davies Author: Xiangrui Meng Closes #6221 from mengxr/SPARK-6657 and squashes the following commits: e3f83fe [Xiangrui Meng] fix sql and streaming doc warnings 2b4371e [Xiangrui Meng] fix mllib python doc warnings --- python/pyspark/mllib/evaluation.py | 5 ++--- python/pyspark/mllib/fpm.py | 12 ++++++------ python/pyspark/sql/dataframe.py | 1 + python/pyspark/streaming/kafka.py | 3 ++- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4c777f2180dc9..a5e5ddc8fe506 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -334,11 +334,10 @@ def ndcgAt(self, k): """ Compute the average NDCG value of all the queries, truncated at ranking position k. The discounted cumulative gain at position k is computed as: - sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current implementation, the relevance value is binary. - - If a query has an empty ground truth set, zero will be used as ndcg together with + If a query has an empty ground truth set, zero will be used as NDCG together with a log warning. """ return self.call("ndcgAt", int(k)) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index d8df02bdbaba9..bdc4a132b1b18 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -61,12 +61,12 @@ class FPGrowth(object): def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: The input data set, each element - contains a transaction. - :param minSupport: The minimal support level - (default: `0.3`). - :param numPartitions: The number of partitions used by parallel - FP-growth (default: same as input data). + + :param data: The input data set, each element contains a + transaction. + :param minSupport: The minimal support level (default: `0.3`). + :param numPartitions: The number of partitions used by + parallel FP-growth (default: same as input data). """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 96d927b9ba35c..e4a191a9ef07f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -943,6 +943,7 @@ def replace(self, to_replace, value, subset=None): Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. 
+ >>> df4.replace(10, 20).show() +----+------+-----+ | age|height| name| diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index e278b29003f69..10a859a532e28 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -132,11 +132,12 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, .. note:: Experimental Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object :param kafkaParams: Additional params for Kafka :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty - map, in which case leaders will be looked up on the driver. + map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object From 814b3dabdf01abc7a2f25aa32284caccadeb7798 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Mon, 18 May 2015 08:46:33 -0700 Subject: [PATCH 097/109] [SPARK-7272] [MLLIB] User guide for PMML model export https://issues.apache.org/jira/browse/SPARK-7272 Author: Vincenzo Selvaggio Closes #6219 from selvinsource/mllib_pmml_model_export_SPARK-7272 and squashes the following commits: c866fb8 [Vincenzo Selvaggio] Update mllib-pmml-model-export.md 1beda98 [Vincenzo Selvaggio] [SPARK-7272] Initial user guide for pmml export d670662 [Vincenzo Selvaggio] Update mllib-pmml-model-export.md 2731375 [Vincenzo Selvaggio] Update mllib-pmml-model-export.md 680dc33 [Vincenzo Selvaggio] Update mllib-pmml-model-export.md 2e298b5 [Vincenzo Selvaggio] Update mllib-pmml-model-export.md a932f51 [Vincenzo Selvaggio] Create mllib-pmml-model-export.md --- docs/mllib-guide.md | 1 + docs/mllib-pmml-model-export.md | 86 +++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 docs/mllib-pmml-model-export.md diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index f8e879496c135..de7d66fb2dedf 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -39,6 +39,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) +* [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md new file mode 100644 index 0000000000000..42ea2ca81f80d --- /dev/null +++ b/docs/mllib-pmml-model-export.md @@ -0,0 +1,86 @@ +--- +layout: global +title: PMML model export - MLlib +displayTitle: MLlib - PMML model export +--- + +* Table of contents +{:toc} + +## MLlib supported models + +MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). + +The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. + + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr><th>MLlib model</th><th>PMML model</th></tr>
+  </thead>
+  <tbody>
+    <tr><td>KMeansModel</td><td>ClusteringModel</td></tr>
+    <tr><td>LinearRegressionModel</td><td>RegressionModel (functionName="regression")</td></tr>
+    <tr><td>RidgeRegressionModel</td><td>RegressionModel (functionName="regression")</td></tr>
+    <tr><td>LassoModel</td><td>RegressionModel (functionName="regression")</td></tr>
+    <tr><td>SVMModel</td><td>RegressionModel (functionName="classification" normalizationMethod="none")</td></tr>
+    <tr><td>Binary LogisticRegressionModel</td><td>RegressionModel (functionName="classification" normalizationMethod="logit")</td></tr>
+  </tbody>
+</table>
+ +## Examples +
+ +
+To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. + +Here a complete example of building a KMeansModel and print it out in PMML format: +{% highlight scala %} +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/kmeans_data.txt") +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using KMeans +val numClusters = 2 +val numIterations = 20 +val clusters = KMeans.train(parsedData, numClusters, numIterations) + +// Export to PMML +println("PMML Model:\n" + clusters.toPMML) +{% endhighlight %} + +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: + +{% highlight scala %} +// Export the model to a String in PMML format +clusters.toPMML + +// Export the model to a local file in PMML format +clusters.toPMML("/tmp/kmeans.xml") + +// Export the model to a directory on a distributed file system in PMML format +clusters.toPMML(sc,"/tmp/kmeans") + +// Export the model to the OutputStream in PMML format +clusters.toPMML(System.out) +{% endhighlight %} + +For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. + +
+ +
From 563bfcc1ab1b1c79b1845230c8c600db85a08fe3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 18 May 2015 10:59:35 -0700 Subject: [PATCH 098/109] [SPARK-7627] [SPARK-7472] DAG visualization: style skipped stages This patch fixes two things: **SPARK-7627.** Cached RDDs no longer light up on the job page. This is a simple fix. **SPARK-7472.** Display skipped stages differently from normal stages. The latter is a major UX issue. Because we link the job viz to the stage viz even for skipped stages, the user may inadvertently click into the stage page of a skipped stage, which is empty. ------------------- Author: Andrew Or Closes #6171 from andrewor14/dag-viz-skipped and squashes the following commits: f261797 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped 0eda358 [Andrew Or] Tweak skipped stage border color c604150 [Andrew Or] Tweak grayscale colors 7010676 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped 762b541 [Andrew Or] Use special prefix for stage clusters to avoid collisions 51c95b9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped b928cd4 [Andrew Or] Fix potential leak + write tests for it 7c4c364 [Andrew Or] Show skipped stages differently 7cc34ce [Andrew Or] Merge branch 'master' of github.com:apache/spark into dag-viz-skipped c121fa2 [Andrew Or] Fix cache color --- .../apache/spark/ui/static/spark-dag-viz.css | 71 +++--- .../apache/spark/ui/static/spark-dag-viz.js | 50 ++-- .../scala/org/apache/spark/ui/UIUtils.scala | 6 +- .../spark/ui/scope/RDDOperationGraph.scala | 10 +- .../ui/scope/RDDOperationGraphListener.scala | 96 ++++++-- .../RDDOperationGraphListenerSuite.scala | 227 ++++++++++++++---- 6 files changed, 352 insertions(+), 108 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index eedefb44b96fc..3b4ae2ed354b8 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -15,32 +15,21 @@ * limitations under the License. 
*/ -#dag-viz-graph svg path { - stroke: #444; - stroke-width: 1.5px; -} - -#dag-viz-graph svg g.cluster rect { - stroke-width: 1px; -} - -#dag-viz-graph svg g.node circle { - fill: #444; +#dag-viz-graph a, #dag-viz-graph a:hover { + text-decoration: none; } -#dag-viz-graph svg g.node rect { - fill: #C3EBFF; - stroke: #3EC0FF; - stroke-width: 1px; +#dag-viz-graph .label { + font-weight: normal; + text-shadow: none; } -#dag-viz-graph svg g.node.cached circle { - fill: #444; +#dag-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; } -#dag-viz-graph svg g.node.cached rect { - fill: #B3F5C5; - stroke: #56F578; +#dag-viz-graph svg g.cluster rect { stroke-width: 1px; } @@ -61,12 +50,23 @@ stroke-width: 1px; } -#dag-viz-graph svg.job g.cluster[class*="stage"] rect { +#dag-viz-graph svg.job g.cluster.skipped rect { + fill: #D6D6D6; + stroke: #B7B7B7; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage rect { fill: #FFFFFF; stroke: #FF99AC; stroke-width: 1px; } +#dag-viz-graph svg.job g.cluster.stage.skipped rect { + stroke: #ADADAD; + stroke-width: 1px; +} + #dag-viz-graph svg.job g#cross-stage-edges path { fill: none; } @@ -75,6 +75,20 @@ fill: #333; } +#dag-viz-graph svg.job g.cluster.skipped text { + fill: #666; +} + +#dag-viz-graph svg.job g.node circle { + fill: #444; +} + +#dag-viz-graph svg.job g.node.cached circle { + fill: #A3F545; + stroke: #52C366; + stroke-width: 2px; +} + /* Stage page specific styles */ #dag-viz-graph svg.stage g.cluster rect { @@ -83,7 +97,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.stage g.cluster[class*="stage"] rect { +#dag-viz-graph svg.stage g.cluster.stage rect { fill: #FFFFFF; stroke: #FFA6B6; stroke-width: 1px; @@ -97,11 +111,14 @@ fill: #333; } -#dag-viz-graph a, #dag-viz-graph a:hover { - text-decoration: none; +#dag-viz-graph svg.stage g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; } -#dag-viz-graph .label { - font-weight: normal; - text-shadow: none; +#dag-viz-graph svg.stage g.node.cached rect { + fill: #B3F5C5; + stroke: #52C366; + stroke-width: 2px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index ee48fd29a6432..aaeba5b1027c9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -57,9 +57,7 @@ var VizConstants = { stageSep: 40, graphPrefix: "graph_", nodePrefix: "node_", - stagePrefix: "stage_", - clusterPrefix: "cluster_", - stageClusterPrefix: "cluster_stage_" + clusterPrefix: "cluster_" }; var JobPageVizConstants = { @@ -133,9 +131,7 @@ function renderDagViz(forJob) { } // Render - var svg = graphContainer() - .append("svg") - .attr("class", jobOrStage); + var svg = graphContainer().append("svg").attr("class", jobOrStage); if (forJob) { renderDagVizForJob(svg); } else { @@ -185,23 +181,32 @@ function renderDagVizForJob(svgContainer) { var dot = metadata.select(".dot-file").text(); var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; - // Link each graph to the corresponding stage page (TODO: handle stage attempts) - var stageLink = $("#stage-" + stageId.replace(VizConstants.stagePrefix, "") + "-0") - .find("a") - .attr("href") + "&expandDagViz=true"; - var container = svgContainer - .append("a") - .attr("xlink:href", stageLink) - .append("g") - .attr("id", containerId); + var isSkipped = metadata.attr("skipped") == "true"; + var container; + if 
(isSkipped) { + container = svgContainer + .append("g") + .attr("id", containerId) + .attr("skipped", "true"); + } else { + // Link each graph to the corresponding stage page (TODO: handle stage attempts) + // Use the link from the stage table so it also works for the history server + var attemptId = 0 + var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) + .select("a") + .attr("href") + "&expandDagViz=true"; + container = svgContainer + .append("a") + .attr("xlink:href", stageLink) + .append("g") + .attr("id", containerId); + } // Now we need to shift the container for this stage so it doesn't overlap with // existing ones, taking into account the position and width of the last stage's // container. We do not need to do this for the first stage of this job. if (i > 0) { - var existingStages = svgContainer - .selectAll("g.cluster") - .filter("[class*=\"" + VizConstants.stageClusterPrefix + "\"]"); + var existingStages = svgContainer.selectAll("g.cluster.stage") if (!existingStages.empty()) { var lastStage = d3.select(existingStages[0].pop()); var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); @@ -214,6 +219,12 @@ function renderDagVizForJob(svgContainer) { // Actually render the stage renderDot(dot, container, true); + // Mark elements as skipped if appropriate. Unfortunately we need to mark all + // elements instead of the parent container because of CSS override rules. + if (isSkipped) { + container.selectAll("g").classed("skipped", true); + } + // Round corners on rectangles container .selectAll("rect") @@ -243,6 +254,9 @@ function renderDot(dot, container, forJob) { var renderer = new dagreD3.render(); preprocessGraphLayout(g, forJob); renderer(container, g); + + // Find the stage cluster and mark it for styling and post-processing + container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); } /* -------------------- * diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ad16becde85dd..6194c50ec8c7c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -352,10 +352,12 @@ private[spark] object UIUtils extends Logging {
-
+
+ {formattedBatchTime}