diff --git a/LICENSE b/LICENSE
index 7950dd6ceb6d..66a2e8f13295 100644
--- a/LICENSE
+++ b/LICENSE
@@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
      (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala)
-     (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/)
+     (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/)
      (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org)
      (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org)
      (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org)
@@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
      (MIT License) RowsGroup (http://datatables.net/license/mit)
      (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html)
      (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE)
+     (MIT License) machinist (https://github.com/typelevel/machinist)
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 879c1f80f2c5..cfa49b94c952 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: SparkR
 Type: Package
-Version: 2.2.0
+Version: 2.2.1
 Title: R Frontend for Apache Spark
 Description: The SparkR package provides an R Frontend for Apache Spark.
 Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"),
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ca45c6f9b0a9..44e39c4abb47 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -122,6 +122,7 @@ exportMethods("arrange",
               "group_by",
               "groupBy",
               "head",
+              "hint",
               "insertInto",
               "intersect",
               "isLocal",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 88a138fd8eb1..a7b1e3b6ae32 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -3642,3 +3642,33 @@ setMethod("checkpoint",
             df <- callJMethod(x@sdf, "checkpoint", as.logical(eager))
             dataFrame(df)
           })
+
+#' hint
+#'
+#' Specifies an execution plan hint and returns a new SparkDataFrame.
+#'
+#' @param x a SparkDataFrame.
+#' @param name the name of the hint.
+#' @param ... optional parameters for the hint.
+#' @return A SparkDataFrame.
+#' @family SparkDataFrame functions
+#' @aliases hint,SparkDataFrame,character-method
+#' @rdname hint
+#' @name hint
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(mtcars)
+#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg")
+#'
+#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl))
+#' }
+#' @note hint since 2.2.0
+setMethod("hint",
+          signature(x = "SparkDataFrame", name = "character"),
+          function(x, name, ...) {
+            parameters <- list(...)
+            stopifnot(all(sapply(parameters, is.character)))
+            jdf <- callJMethod(x@sdf, "hint", name, parameters)
+            dataFrame(jdf)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 945676c7f10b..f8ae5526bc72 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -572,6 +572,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
 #' @export
 setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
 
+#' @rdname hint
+#' @export
+setGeneric("hint", function(x, name, ...) { standardGeneric("hint") })
+
 #' @rdname insertInto
 #' @export
 setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") })
@@ -1469,7 +1473,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml")
 
 #' @rdname awaitTermination
 #' @export
-setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") })
+setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") })
 
 #' @rdname isActive
 #' @export
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 4db9cc30fb0c..306a9b867653 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
 #' @note NaiveBayesModel since 2.0.0
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
-#' linear SVM Model
+#' Linear SVM Model
 #'
-#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package
+#' Fits a linear SVM model against a SparkDataFrame, similar to svm in the e1071 package.
+#' Currently only supports binary classification with a linear kernel.
 #' Users can print, make predictions on the produced model and save the model to the input path.
 #'
 #' @param data SparkDataFrame for training.
 #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param regParam The regularization parameter.
+#' @param regParam The regularization parameter. Only supports L2 regularization currently.
 #' @param maxIter Maximum iteration number.
 #' @param tol Convergence tolerance of iterations.
 #' @param standardization Whether to standardize the training features before fitting the model. The coefficients
 #'
@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
             new("LinearSVCModel", jobj = jobj)
           })
 
-# Predicted values based on an LinearSVCModel model
+# Predicted values based on a LinearSVCModel model
 
 #' @param newData a SparkDataFrame for testing.
-#' @return \code{predict} returns the predicted values based on an LinearSVCModel.
+#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
 #' @rdname spark.svmLinear
 #' @aliases predict,LinearSVCModel,SparkDataFrame-method
 #' @export
@@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"),
             predict_internal(object, newData)
           })
 
-# Get the summary of an LinearSVCModel
+# Get the summary of a LinearSVCModel
 
-#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list includes \code{coefficients} (coefficients of the fitted model),
-#'         \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
-#'         \code{numFeatures} (number of features).
+#'         \code{numClasses} (number of classes), \code{numFeatures} (number of features).
 #' @rdname spark.svmLinear
 #' @aliases summary,LinearSVCModel-method
 #' @export
@@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"),
 setMethod("summary", signature(object = "LinearSVCModel"),
           function(object) {
             jobj <- object@jobj
-            features <- callJMethod(jobj, "features")
-            labels <- callJMethod(jobj, "labels")
-            coefficients <- callJMethod(jobj, "coefficients")
-            nCol <- length(coefficients) / length(features)
-            coefficients <- matrix(unlist(coefficients), ncol = nCol)
-            intercept <- callJMethod(jobj, "intercept")
+            features <- callJMethod(jobj, "rFeatures")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            coefficients <- as.matrix(unlist(coefficients))
+            colnames(coefficients) <- c("Estimate")
+            rownames(coefficients) <- unlist(features)
             numClasses <- callJMethod(jobj, "numClasses")
             numFeatures <- callJMethod(jobj, "numFeatures")
-            if (nCol == 1) {
-              colnames(coefficients) <- c("Estimate")
-            } else {
-              colnames(coefficients) <- unlist(labels)
-            }
-            rownames(coefficients) <- unlist(features)
-            list(coefficients = coefficients, intercept = intercept,
-                 numClasses = numClasses, numFeatures = numFeatures)
+            list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
           })
 
 # Save fitted LinearSVCModel to the input path
diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R
index e353d2dd07c3..8390bd5e6de7 100644
--- a/R/pkg/R/streaming.R
+++ b/R/pkg/R/streaming.R
@@ -169,8 +169,10 @@ setMethod("isActive",
 #' immediately.
 #'
 #' @param x a StreamingQuery.
-#' @param timeout time to wait in milliseconds
-#' @return TRUE if query has terminated within the timeout period.
+#' @param timeout time to wait in milliseconds; if omitted, waits indefinitely until \code{stopQuery}
+#'                is called or an error has occurred.
+#' @return TRUE if the query has terminated within the timeout period; nothing if timeout is not
+#'         specified.
#' @rdname awaitTermination #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method @@ -182,8 +184,12 @@ setMethod("isActive", #' @note experimental setMethod("awaitTermination", signature(x = "StreamingQuery"), - function(x, timeout) { - handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } }) #' stopQuery diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fbc89e98847b..b19556a1d57e 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -899,3 +899,19 @@ basenameSansExtFromUrl <- function(url) { isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } + +is_cran <- function() { + !identical(Sys.getenv("NOT_CRAN"), "true") +} + +is_windows <- function() { + .Platform$OS.type == "windows" +} + +hadoop_home_set <- function() { + !identical(Sys.getenv("HADOOP_HOME"), "") +} + +not_cran_or_windows_with_hadoop <- function() { + !is_cran() && (!is_windows() || hadoop_home_set()) +} diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index c9615c8d4faf..e2241e03b55f 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 4bc935c79eb0..ac706261999f 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa8..6e160fae1afe 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,9 +17,11 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 1d777ddb286d..00d684e1a49e 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,6 +17,8 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } @@ -25,3 +27,6 @@ test_that("sparkJars tag in SparkContext", { abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) + +message("--- End test (Windows) ", 
as.POSIXct(Sys.time(), tz = "GMT")) +message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e..00954fa31b0e 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,12 +18,14 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e620440..236cb3885445 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f4..254f8f522a70 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # 
JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -26,6 +26,8 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf3..3d53bebab630 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c84711349111..f4893c4003f8 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -18,6 +18,8 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. 
@@ -55,8 +57,10 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { - sc <- suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -65,7 +69,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -73,12 +77,14 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - sc <- sparkR.sparkContext() # sc should get id 0 + skip_on_cran() + + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -96,7 +102,9 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - sc <- sparkR.sparkContext() + skip_on_cran() + + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - sparkR.sparkContext() + skip_on_cran() + + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -161,14 +175,16 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() }) test_that("add and get file to be downloaded with Spark job on every node", { - sparkR.sparkContext() + skip_on_cran() + + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. 
path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd..d7d9eeed1575 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R index 7348c893d0af..8b3b4f73de17 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a5..82e588dc460d 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib classification algorithms, except for tree-based algorithms") # Tests for MLlib classification algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.svmLinear", { + skip_on_cran() + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) @@ -38,9 +40,8 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) expect_true(all(abs(coefs - expected_coefs) < 0.1)) - expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) # Test prediction with string label prediction <- predict(model, training) @@ -50,15 +51,17 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- 
summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # Test prediction with numeric label label <- c(0.0, 0.0, 0.0, 1.0, 1.0) @@ -128,15 +131,17 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # R code to reproduce the result. # nolint start @@ -223,6 +228,8 @@ test_that("spark.logit", { }) test_that("spark.mlp", { + skip_on_cran() + df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), @@ -243,19 +250,21 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + } # Test default parameter model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) @@ -284,22 +293,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- 
spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +308,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { @@ -367,16 +363,18 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + } # Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1661e987b730..e827e961ab4c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib clustering algorithms") # Tests for MLlib clustering algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.bisectingKmeans", { + skip_on_cran() + newIris <- iris newIris$Species <- NULL training <- suppressWarnings(createDataFrame(newIris)) @@ -53,18 +55,20 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") + write.ml(model, modelPath) + 
expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } }) test_that("spark.gaussianMixture", { @@ -125,18 +129,20 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) + } }) test_that("spark.kmeans", { @@ -171,18 +177,20 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } # Test Kmeans on dataset that is sensitive to seed value col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) @@ -236,25 +244,29 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_true(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - expect_equal(trainingLogLikelihood, 
stats2$trainingLogLikelihood) - expect_equal(logPrior, stats2$logPrior) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) + } }) test_that("spark.lda with text input", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -297,6 +309,8 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R index c38f1133897d..4e10ca1e4f50 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib frequent pattern mining") # Tests for MLlib frequent pattern mining algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.fpGrowth", { data <- selectExpr(createDataFrame(data.frame(items = c( @@ -62,15 +62,17 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") - write.ml(model, modelPath, overwrite = TRUE) - loaded_model <- read.ml(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) - expect_equivalent( - itemsets, - collect(spark.freqItemsets(loaded_model))) + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) - unlink(modelPath) + unlink(modelPath) + } model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) expect_equal( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R index 6b1040db9305..cc8064f88d27 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib recommendation algorithms") # Tests for MLlib recommendation algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -37,29 +37,31 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- 
read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - unlink(modelPath) + unlink(modelPath) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 3e9ad7719807..b05fdd350ca2 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -20,9 +20,11 @@ library(testthat) context("MLlib regression algorithms, except for tree-based algorithms") # Tests for MLlib regression algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -195,6 +197,8 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -222,6 +226,8 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . 
- Species + 0, data = training) @@ -248,6 +254,8 @@ test_that("formula of glm", { }) test_that("glm and predict", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -292,6 +300,8 @@ test_that("glm and predict", { }) test_that("glm summary", { + skip_on_cran() + # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -341,6 +351,8 @@ test_that("glm summary", { }) test_that("glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -389,14 +401,16 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) + } }) test_that("spark.survreg", { @@ -438,17 +452,19 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + } # Test survival::survreg if (requireNamespace("survival", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/inst/tests/testthat/test_mllib_stat.R index beb148e7702f..1600833a5d03 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_stat.R +++ b/R/pkg/inst/tests/testthat/test_mllib_stat.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib statistics algorithms") # Tests for MLlib statistics algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.kstest", { data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e0802a9b02d1..923f535c34cd 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib tree-based 
algorithms") # Tests for MLlib tree-based algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -28,6 +28,8 @@ absoluteSparkPath <- function(x) { } test_that("spark.gbt", { + skip_on_cran() + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) @@ -44,21 +46,23 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification # label must be binary - GBTClassifier currently only supports binary classification. 
@@ -76,17 +80,19 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) df <- suppressWarnings(createDataFrame(iris2)) @@ -99,10 +105,12 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), - source = "libsvm") - model <- spark.gbt(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 692) + if (not_cran_or_windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) + } }) test_that("spark.randomForest", { @@ -136,21 +144,23 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification data <- suppressWarnings(createDataFrame(iris)) @@ -168,17 +178,19 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = 
"spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } # Test numeric response variable labelToIndex <- function(species) { @@ -203,10 +215,12 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) + if (not_cran_or_windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba469..52d4c93ed959 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,12 +33,14 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git 
a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b72c801dd958..fb244e1d49e2 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { + skip_on_cran() + expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual 
<- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,6 +351,8 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -322,6 +374,8 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,6 +441,8 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- 
collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 +592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- 
sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1d..18320ea44b38 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ 
test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R index f73fc6baecce..a40981c188f7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -18,6 +18,8 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { + skip_on_cran() + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a6c9a809ab1..d2d51915d72d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -61,7 +61,11 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session() +sparkSession <- if (not_cran_or_windows_with_hadoop()) { + sparkR.session(master = sparkRTestMaster) + } else { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -96,16 +100,26 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -194,6 +208,8 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -291,6 +307,8 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -303,54 +321,58 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - 
mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "NA,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - # default "header" is false, inferSchema to handle "year" as "int" - df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") - expect_equal(count(df), 4) - expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) - expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), - sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) - - # since "year" is "int", let's skip the NA values - withoutna <- na.omit(df, how = "any", cols = "year") - expect_equal(count(withoutna), 3) - - unlink(csvPath) - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "Empty,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") - expect_equal(count(df2), 4) - withoutna2 <- na.omit(df2, how = "any", cols = "year") - expect_equal(count(withoutna2), 3) - expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) - - # writing csv file - csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") - write.df(df2, path = csvPath2, "csv", header = "true") - df3 <- read.df(csvPath2, "csv", header = "true") - expect_equal(nrow(df3), nrow(df2)) - expect_equal(colnames(df3), colnames(df2)) - csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) - expect_equal(colnames(df3), colnames(csv)) - - unlink(csvPath) - unlink(csvPath2) + if (not_cran_or_windows_with_hadoop()) { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), 
nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) + } }) test_that("Support other types for options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -405,6 +427,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -516,6 +540,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -528,6 +554,8 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -570,51 +598,55 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - # Test read.df - df <- read.df(jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test read.df with a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(jsonPath, "json", schema) - expect_is(df1, "SparkDataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Test loadDF - df2 <- loadDF(jsonPath, "json", schema) - expect_is(df2, "SparkDataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) - - # Test read.json - df <- read.json(jsonPath) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test write.df - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") - write.df(df, jsonPath2, "json", mode = "overwrite") - - # Test write.json - jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") - write.json(df, jsonPath3) - - # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) - expect_is(jsonDF1, "SparkDataFrame") - expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) - - unlink(jsonPath2) - unlink(jsonPath3) + if (not_cran_or_windows_with_hadoop()) { + # Test read.df + df <- read.df(jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(jsonPath, "json", schema) + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(jsonPath, "json", schema) + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(jsonPath) + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = 
"jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "SparkDataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "SparkDataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) + } }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -628,6 +660,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -642,24 +676,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( @@ -684,6 +721,8 @@ test_that( }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -696,33 +735,35 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - df <- read.df(jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(parquetPath2, "parquet") - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql("select * from table1")), 5) - expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - expect_true(dropTempView("table1")) - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql("select * from table1")), 2) - expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - expect_true(dropTempView("table1")) - - unlink(jsonPath2) - unlink(parquetPath2) + if (not_cran_or_windows_with_hadoop()) { + df <- 
read.df(jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(parquetPath2, "parquet") + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + expect_true(dropTempView("table1")) + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + expect_true(dropTempView("table1")) + + unlink(jsonPath2) + unlink(parquetPath2) + } }) test_that("tableToDF() returns a new DataFrame", { @@ -737,6 +778,8 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -744,6 +787,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -754,6 +799,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -783,6 +830,8 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { + skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -795,6 +844,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -863,6 +914,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -902,14 +955,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - checkpointDir <- file.path(tempdir(), "cproot") - expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) - - setCheckpointDir(checkpointDir) - df <- read.json(jsonPath) - df <- checkpoint(df) - expect_is(df, "SparkDataFrame") - expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + if (not_cran_or_windows_with_hadoop()) { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + } }) test_that("schema(), 
dtypes(), columns(), names() return the correct values/format", { @@ -1267,45 +1322,47 @@ test_that("column calculation", { }) test_that("test HiveContext", { - setHiveContext(sc) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - createTable("people", source = "json", schema = schema) - df <- read.df(jsonPathNa, "json", schema) - insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - sql("DROP TABLE people") - - df <- createTable("json", jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - df2 <- sql("select * from json") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "json2", "json", "append", path = jsonPath2) - df3 <- sql("select * from json2") - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - unlink(jsonPath2) - - hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "hivetestbl", path = hivetestDataPath) - df4 <- sql("select * from hivetestbl") - expect_is(df4, "SparkDataFrame") - expect_equal(count(df4), 3) - unlink(hivetestDataPath) - - parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) - df5 <- sql("select * from parquetest") - expect_is(df5, "SparkDataFrame") - expect_equal(count(df5), 3) - unlink(parquetDataPath) - - unsetHiveContext() + if (not_cran_or_windows_with_hadoop()) { + setHiveContext(sc) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + df2 <- sql("select * from json") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "json2", "json", "append", path = jsonPath2) + df3 <- sql("select * from json2") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "hivetestbl", path = hivetestDataPath) + df4 <- sql("select * from hivetestbl") + expect_is(df4, "SparkDataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) + df5 <- sql("select * from parquetest") + expect_is(df5, "SparkDataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) + + unsetHiveContext() + } }) test_that("column operators", { @@ -1317,6 +1374,8 @@ test_that("column operators", { }) test_that("column functions", { + skip_on_cran() + c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) @@ -1638,6 +1697,8 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { }) test_that("group by, agg functions", { + skip_on_cran() + df <- 
read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -1793,6 +1854,8 @@ test_that("filter() on a DataFrame", { }) test_that("join(), crossJoin() and merge() on a DataFrame", { + skip_on_cran() + df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", @@ -1890,6 +1953,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { @@ -2049,6 +2124,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2070,6 +2147,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2085,37 +2164,41 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - df <- read.df(jsonPath, "json") - # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(parquetPath, "parquet") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.parquet(df, parquetPath2) - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) - expect_is(parquetDF, "SparkDataFrame") - expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) - - unlink(parquetPath2) - unlink(parquetPath3) - unlink(parquetPath4) + if (not_cran_or_windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(parquetPath, "parquet") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "SparkDataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) + expect_is(parquetDF2, "SparkDataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- 
"overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) + } }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2133,6 +2216,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2154,6 +2239,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2387,6 +2474,8 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2395,6 +2484,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2617,6 +2708,7 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { + skip_on_cran() df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) @@ -2638,6 +2730,8 @@ test_that("dapplyCollect() on DataFrame with a binary column", { }) test_that("repartition by columns on DataFrame", { + skip_on_cran() + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -2676,6 +2770,8 @@ test_that("repartition by columns on DataFrame", { }) test_that("coalesce, repartition, numPartitions", { + skip_on_cran() + df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) expect_equal(getNumPartitions(coalesce(df, 3)), 3) @@ -2695,6 +2791,8 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { + skip_on_cran() + df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) @@ -2812,6 +2910,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -2845,6 +2945,8 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { }) test_that("randomSplit", { + skip_on_cran() + num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) weights <- c(2, 3, 5) @@ -2891,6 +2993,8 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -2906,6 +3010,8 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API 
in Java without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -2932,6 +3038,8 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3056,6 +3164,8 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory + # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 03b1bd3dc1f4..b20b4312fbaa 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -21,7 +21,7 @@ context("Structured Streaming") # Tests for Structured Streaming functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") if (.Platform$OS.type == "windows") { @@ -47,29 +47,37 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) }) test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -82,6 +90,8 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { + skip_on_cran() + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -92,6 +102,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") 
expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") @@ -107,6 +118,8 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { + skip_on_cran() + c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -116,6 +129,8 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { + skip_on_cran() + # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -124,6 +139,8 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, @@ -131,7 +148,7 @@ test_that("Terminated by error", { expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), NA) - expect_error(awaitTermination(q, 1), + expect_error(awaitTermination(q, 5 * 1000), paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", " 'maxFilesPerTrigger', must be a positive integer).*")) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index aaa532856c3d..c00723ba31f4 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,10 +30,12 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e939..e8a961cb3e87 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,12 +18,14 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- 
tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665..01614716afa3 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,11 +18,12 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. 
@@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -167,6 +169,7 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { + skip_on_cran() method <- "getSQLDataType" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "unknown"), @@ -177,6 +180,8 @@ test_that("captureJVMException", { }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) @@ -237,3 +242,6 @@ test_that("basenameSansExtFromUrl", { }) sparkR.session.stop() + +message("--- End test (utils) ", as.POSIXct(Sys.time(), tz = "GMT")) +message("elapsed ", (proc.time() - timer_ptm)[3]) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 29812f872c78..f0bef4f6d266 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,6 +21,12 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} +message("--- Start test ", as.POSIXct(Sys.time(), tz = "GMT")) +timer_ptm <- proc.time() + # Setup global test environment # Install Spark first to set SPARK_HOME install.spark() @@ -31,4 +37,9 @@ sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRTestMaster <- "local[1]" +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index a6ff650c33fe..031c3bc41f7c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,8 +46,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() +sparkR.session(master = "local[1]") ``` -```{r, message=FALSE, results="hide"} +```{r, eval=FALSE} sparkR.session() ``` @@ -65,7 +66,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -182,7 +183,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. 
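For reference, a minimal sketch of how `read.df` and `write.df` are typically invoked, using the same `header`, `inferSchema`, and `mode` options exercised by the CSV and Parquet tests above; the temporary paths and sample data here are illustrative only, not part of the patch:

```r
# Assumes an active SparkR session (e.g. started with sparkR.session()).
# Illustrative only: write a small CSV to a temp path, then load it with read.df.
csvPath <- tempfile(pattern = "sparkr-example", fileext = ".csv")
writeLines(c("year,make,model",
             "2012,Tesla,S",
             "1997,Ford,E350"), csvPath)

# read.df takes the file path and the data source type; additional data source
# options such as header and inferSchema are passed as named string arguments.
df <- read.df(csvPath, "csv", header = "true", inferSchema = "true")
head(df)

# write.df mirrors this pattern: choose a source format and a save mode.
parquetPath <- tempfile(pattern = "sparkr-example", fileext = ".parquet")
write.df(df, path = parquetPath, source = "parquet", mode = "overwrite")

unlink(csvPath)
unlink(parquetPath, recursive = TRUE)
```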
SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +233,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -364,7 +365,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -390,7 +391,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -443,20 +444,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. 
The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -505,6 +506,10 @@ SparkR supports the following machine learning models and algorithms. * Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -653,6 +658,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -707,7 +713,7 @@ summary(tweedieGLM1) ``` We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: ```{r} -tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 1.2, link.power = 0.0) summary(tweedieGLM2) ``` @@ -760,7 +766,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -800,7 +806,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. 
```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -831,9 +837,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -881,9 +887,9 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -891,7 +897,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -901,11 +907,42 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). @@ -930,7 +967,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. 
```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -952,6 +989,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes @@ -962,19 +1065,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. 
The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. 
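To make the delegation described in the Architecture paragraph concrete, here is a minimal Scala sketch of the JVM-side Spark SQL calls that the R `createOrReplaceTempView()` and `sql()` wrappers drive through the `sdf` reference. This is an illustrative sketch, not code from the patch; the `local[*]` master, app name, and JSON path are assumptions.

```scala
import org.apache.spark.sql.SparkSession

object SparkRBackendSketch {
  def main(args: Array[String]): Unit = {
    // Assumed local session; SparkR normally creates this for you via sparkR.session().
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sparkr-backend-sketch")
      .getOrCreate()

    // Same people.json used in the vignette's SQL Queries example.
    val people = spark.read.json("examples/src/main/resources/people.json")

    // The two steps the R API delegates to the backend: register a temp view, then run SQL.
    people.createOrReplaceTempView("people")
    val teenagers = spark.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
    teenagers.show()

    spark.stop()
  }
}
```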
diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76d..29764f48bd15 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad0..f4d13b8515cd 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,6 +48,9 @@ install: build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +59,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c..da7b0c9d1b93 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/bin/spark-class b/bin/spark-class index 77ea40cc3794..65d3b9612909 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 9faa7d65f83e..f6157f42843e 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. 
set RUNNER=java if not "x%JAVA_HOME%"=="x" ( - set RUNNER="%JAVA_HOME%\bin\java" + set RUNNER=%JAVA_HOME%\bin\java ) else ( where /q "%RUNNER%" if ERRORLEVEL 1 ( diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c06..7577253dd039 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index ee367f9998db..ad8e8b44d201 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import scala.Tuple2; + import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -94,6 +96,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { return nextChunk; } + @Override + public ManagedBuffer openStream(String streamChunkId) { + Tuple2 streamIdAndChunkId = parseStreamChunkId(streamChunkId); + return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2); + } + + public static String genStreamChunkId(long streamId, int chunkId) { + return String.format("%d_%d", streamId, chunkId); + } + + public static Tuple2 parseStreamChunkId(String streamChunkId) { + String[] array = streamChunkId.split("_"); + assert array.length == 2: + "Stream id and chunk index should be specified when open stream for fetching block."; + long streamId = Long.valueOf(array[0]); + int chunkIndex = Integer.valueOf(array[1]); + return new Tuple2<>(streamId, chunkIndex); + } + @Override public void connectionTerminated(Channel channel) { // Close all streams which have been associated with the channel. 
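The `openStream` support added above keys each chunk by a composite `streamId_chunkIndex` string. As a self-contained illustration of that encoding (a sketch mirroring `genStreamChunkId`/`parseStreamChunkId`, not the Spark class itself):

```scala
object StreamChunkIdSketch {
  // Mirrors genStreamChunkId: "<streamId>_<chunkIndex>".
  def gen(streamId: Long, chunkIndex: Int): String = s"${streamId}_${chunkIndex}"

  // Mirrors parseStreamChunkId: split on '_' and recover the two components.
  def parse(streamChunkId: String): (Long, Int) = {
    val parts = streamChunkId.split("_")
    require(parts.length == 2,
      "Stream id and chunk index should be specified when opening a stream for fetching a block")
    (parts(0).toLong, parts(1).toInt)
  }

  def main(args: Array[String]): Unit = {
    val id = gen(12345L, 7)        // "12345_7"
    val expected = (12345L, 7)
    assert(parse(id) == expected)  // round-trips back to (streamId, chunkIndex)
    println(id)
  }
}
```

In the patch this id is generated on the fetcher side and parsed back in `openStream`, which then reuses the existing `getChunk(streamId, chunkIndex)` path.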
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9..558864ae4faa 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6daf9609d76d..c0f1da50f5e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? 
block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 2c5827bf7dc5..269fa72dad5f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; @@ -86,14 +87,16 @@ public void fetchBlocks( int port, String execId, String[] blockIds, - BlockFetchingListener listener) { + BlockFetchingListener listener, + File[] shuffleFiles) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start(); + new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf, + shuffleFiles).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 35f69fe35c94..5f428759252a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,19 +17,28 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.util.Arrays; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.TransportConf; /** * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and @@ -48,6 +57,8 @@ public class OneForOneBlockFetcher { private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; + 
private TransportConf transportConf = null; + private File[] shuffleFiles = null; private StreamHandle streamHandle = null; @@ -56,12 +67,20 @@ public OneForOneBlockFetcher( String appId, String execId, String[] blockIds, - BlockFetchingListener listener) { + BlockFetchingListener listener, + TransportConf transportConf, + File[] shuffleFiles) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); + this.transportConf = transportConf; + if (shuffleFiles != null) { + this.shuffleFiles = shuffleFiles; + assert this.shuffleFiles.length == blockIds.length: + "Number of shuffle files should equal to blocks"; + } } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -100,7 +119,12 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + if (shuffleFiles != null) { + client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), + new DownloadCallback(shuffleFiles[i], i)); + } else { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } } } catch (Exception e) { logger.error("Failed while starting block fetches after success", e); @@ -126,4 +150,38 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } } + + private class DownloadCallback implements StreamCallback { + + private WritableByteChannel channel = null; + private File targetFile = null; + private int chunkIndex; + + public DownloadCallback(File targetFile, int chunkIndex) throws IOException { + this.targetFile = targetFile; + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + this.chunkIndex = chunkIndex; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + channel.write(buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index f72ab40690d0..978ff5a2a869 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.Closeable; +import java.io.File; /** Provides an interface for reading shuffle files, either from an Executor or external service. 
*/ public abstract class ShuffleClient implements Closeable { @@ -40,5 +41,6 @@ public abstract void fetchBlocks( int port, String execId, String[] blockIds, - BlockFetchingListener listener); + BlockFetchingListener listener, + File[] shuffleFiles); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c0e170e5b935..0c054fc5db8f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -204,7 +204,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener); + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e47a72c9d16c..4d48b1897038 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,8 +88,6 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b8ae04eefb97..d1d8f5b4e188 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }); + }, null); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); @@ -216,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still 
fail, as we start by checking for all block. - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 3e51fea3cf0e..61d82214e7d3 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -46,8 +46,13 @@ import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; public class OneForOneBlockFetcherSuite { + + private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + @Test public void testFetchOne() { LinkedHashMap blocks = Maps.newLinkedHashMap(); @@ -126,7 +131,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap { diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd4446..de66617d2fa2 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml @@ -87,6 +87,9 @@ *:* + + org.scala-lang:scala-library + @@ -98,7 +101,7 @@ - + com.fasterxml.jackson ${spark.shade.packageName}.com.fasterxml.jackson diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index c7620d0fe128..fd50e3a4bfb9 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.List; import java.util.Map; @@ -340,9 +339,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. 
*/ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -350,7 +349,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -363,17 +362,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. - File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -381,7 +384,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b66..076d98af834d 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4..e74d84a5b3b9 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e715..76783abe36a2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1321b8318115..aca6fca00c48 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,8 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + // Since java.nio.Bits.unaligned() 
doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it _unaligned = true; } else { try { diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 55cb094b4af4..2ecb4f1464a4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -15,6 +15,6 @@ # limitations under the License. # -spark.mesos.executor.docker.image: +spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 94bd2c477a35..b7c985ace69c 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,7 +34,6 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) diff --git a/core/pom.xml b/core/pom.xml index 24ce36deeb16..254a9b9ac318 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index aa0b37323132..761ba9de659d 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import javax.annotation.concurrent.GuardedBy; import java.io.IOException; +import java.nio.channels.ClosedByInterruptException; import java.util.Arrays; import java.util.ArrayList; import java.util.BitSet; @@ -155,7 +156,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } @@ -183,6 +185,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { break; } } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). + logger.error("error while calling spill() on " + c, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); throw new OutOfMemoryError("error while calling spill() on " + c + " : " @@ -200,6 +206,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { Utils.bytesToString(released), consumer); got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). 
+ logger.error("error while calling spill() on " + consumer, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 8a1771848dee..2fde5c300f07 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -422,17 +422,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - while (bytesToTransfer > 0) { - final long actualBytesTransferred = spillInputChannel.transferTo( - spillInputChannelPositions[i], - bytesToTransfer, - mergedFileOutputChannel); - spillInputChannelPositions[i] += actualBytesTransferred; - bytesToTransfer -= actualBytesTransferred; - } + Utils.copyFileStreamNIO( + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 930a0698928d..d430d8c5fb35 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -26,7 +26,6 @@ function getThreadDumpEnabled() { } function formatStatus(status, type) { - if (type !== 'display') return status; if (status) { return "Active" } else { @@ -253,10 +252,14 @@ $(document).ready(function () { var deadTotalBlacklisted = 0; response.forEach(function (exec) { - exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; - exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; - exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; - exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? exec.maxOffHeapMemory : 0; + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? 
exec.memoryMetrics : memoryMetrics; }); response.forEach(function (exec) { @@ -264,10 +267,10 @@ $(document).ready(function () { allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; - allOnHeapMemoryUsed += exec.onHeapMemoryUsed; - allOnHeapMaxMemory += exec.maxOnHeapMemory; - allOffHeapMemoryUsed += exec.offHeapMemoryUsed; - allOffHeapMaxMemory += exec.maxOffHeapMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -286,10 +289,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; - activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; - activeOnHeapMaxMemory += exec.maxOnHeapMemory; - activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; - activeOffHeapMaxMemory += exec.maxOffHeapMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -308,10 +311,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; - deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; - deadOnHeapMaxMemory += exec.maxOnHeapMemory; - deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; - deadOffHeapMaxMemory += exec.maxOffHeapMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -413,7 +416,6 @@ $(document).ready(function () { }, {data: 'hostPort'}, {data: 'isActive', render: function (data, type, row) { - if (type !== 'display') return data; if (row.isBlacklisted) return "Blacklisted"; else return formatStatus (data, type); } @@ -431,10 +433,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.onHeapMemoryUsed; + return row.memoryMetrics.usedOnHeapStorageMemory; else - return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOnHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('on_heap_memory') @@ -443,10 +445,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.offHeapMemoryUsed; + return row.memoryMetrics.usedOffHeapStorageMemory; else - return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOffHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); }, 
"fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('off_heap_memory') @@ -488,24 +490,20 @@ $(document).ready(function () { {data: 'totalInputBytes', render: formatBytes}, {data: 'totalShuffleRead', render: formatBytes}, {data: 'totalShuffleWrite', render: formatBytes}, - {data: 'executorLogs', render: formatLogsCells}, + {name: 'executorLogsCol', data: 'executorLogs', render: formatLogsCells}, { + name: 'threadDumpCol', data: 'id', render: function (data, type) { return type === 'display' ? ("Thread Dump" ) : data; } } ], - "columnDefs": [ - { - "targets": [ 16 ], - "visible": getThreadDumpEnabled() - } - ], "order": [[0, "asc"]] }; var dt = $(selector).DataTable(conf); - dt.column(15).visible(logsExist(response)); + dt.column('executorLogsCol:name').visible(logsExist(response)); + dt.column('threadDumpCol:name').visible(getThreadDumpEnabled()); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 42e2d9abdeb5..bfe31aae555b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -20,47 +20,47 @@ - + App ID - + App Name - + Attempt ID - + Started - - + + Completed - + Duration - + Spark User - + Last Updated - + Event Log @@ -73,11 +73,11 @@ {{#attempts}} {{attemptId}} {{startTime}} - {{endTime}} + {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 54810edaf146..5ec1ce15a212 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -120,6 +120,9 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? 
attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } @@ -174,6 +177,13 @@ $(document).ready(function() { } } + if (requestedIncomplete) { + var completedCells = document.getElementsByClassName("completedColumn"); + for (i = 0; i < completedCells.length; i++) { + completedCells[i].style.display='none'; + } + } + var durationCells = document.getElementsByClassName("durationClass"); for (i = 0; i < durationCells.length; i++) { var timeInMilliseconds = parseInt(durationCells[i].title); @@ -185,7 +195,7 @@ $(document).ready(function() { } $(selector).DataTable(conf); - $('#hisotry-summary [data-toggle="tooltip"]').tooltip(); + $('#history-summary [data-toggle="tooltip"]').tooltip(); }); }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index ff241470f32d..9960d5c34d1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -207,8 +207,8 @@ sorttable = { hasInputs = (typeof node.getElementsByTagName == 'function') && node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { + + if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) { return node.getAttribute("sorttable_customkey"); } else if (typeof node.textContent != 'undefined' && !hasInputs) { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 261b3329a7b9..fcc72ff49276 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index a50600f1488c..089969398801 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -261,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S private def getImpl(timeout: Duration): T = { // This will throw TimeoutException on timeout: - Await.ready(futureAction, timeout) + ThreadUtils.awaitReady(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) case scala.util.Failure(exception) => diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 99efc4893fda..1a2443f7ee78 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) 
cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** @@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus @@ -1800,40 +1801,39 @@ class SparkContext(config: SparkConf) extends Logging { * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. */ def addJar(path: String) { + def addJarFile(file: File): String = { + try { + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") + } + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") + } + env.rpcEnv.fileServer.addJar(file) + } catch { + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) + null + } + } + if (path == null) { logWarning("null specified as parameter to addJar") } else { - var key = "" - if (path.contains("\\")) { + val key = if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.rpcEnv.fileServer.addJar(new File(path)) + addJarFile(new File(path)) } else { val uri = new URI(path) // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies Utils.validateURL(uri) - key = uri.getScheme match { + uri.getScheme match { // A JAR file which exists only on the driver node - case null | "file" => - try { - val file = new File(uri.getPath) - if (!file.exists()) { - throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") - } - if (file.isDirectory) { - throw new IllegalArgumentException( - s"Directory ${file.getAbsoluteFile} is not allowed for addJar") - } - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case NonFatal(e) => - logError(s"Failed to add $path to Spark environment", e) - null - } + case null | "file" => addJarFile(new File(uri.getPath)) // A JAR file which exists locally on every worker node - case "local" => - "file:" + uri.getPath - case _ => - path + case "local" => "file:" + uri.getPath + case _ => path } } if (key != null) { @@ -1938,6 +1938,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. 
System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 8cd1d1c96aa0..01d8973e1bb0 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl( /** Marks the task as completed and triggers the completion listeners. */ @GuardedBy("this") - private[spark] def markTaskCompleted(): Unit = synchronized { + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8..fa35e4568819 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable @@ -34,6 +34,16 @@ private[deploy] object DeployMessages { // Worker to Master + /** + * @param id the worker id + * @param host the worker host + * @param port the worker post + * @param worker the worker endpoint ref + * @param cores the core number of worker + * @param memory the memory size of worker + * @param workerWebUiUrl the worker Web UI address + * @param masterAddress the master address used by the worker to connect + */ case class RegisterWorker( id: String, host: String, @@ -41,7 +51,8 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - workerWebUiUrl: String) + workerWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage { Utils.checkHost(host, "Required hostname") assert (port > 0) @@ -80,8 +91,16 @@ private[deploy] object DeployMessages { sealed trait RegisterWorkerResponse - case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage - with RegisterWorkerResponse + /** + * @param master the master ref + * @param masterWebUiUrl the master Web UI address + * @param masterAddress the master address used by the worker to connect. It should be + * [[RegisterWorker.masterAddress]]. 
+ */ + case class RegisteredWorker( + master: RpcEndpointRef, + masterWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 050778a895c0..7d356e8fc1c0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -92,6 +92,9 @@ private[deploy] object RPackageUtils extends Logging { * Exposed for testing. */ private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + if (jar.getManifest == null) { + return false + } val manifest = jar.getManifest.getMainAttributes manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index bae7a3f307f5..6afe58bff522 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -23,11 +23,13 @@ import java.text.DateFormat import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -142,14 +144,29 @@ class SparkHadoopUtil extends Logging { * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will * return the bytes read on r since t. - * - * @return None if the required method can't be found. */ private[spark] def getFSBytesReadOnThreadCallback(): () => Long = { - val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics) - val f = () => threadStats.map(_.getBytesRead).sum - val baselineBytesRead = f() - () => f() - baselineBytesRead + val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum + val baseline = (Thread.currentThread().getId, f()) + + /** + * This function may be called in both spawned child threads and parent task thread (in + * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics. + * So we need a map to track the bytes read from the child threads and parent thread, + * summing them together to get the bytes read of this task. 
+ */ + new Function0[Long] { + private val bytesReadMap = new mutable.HashMap[Long, Long]() + + override def apply(): Long = { + bytesReadMap.synchronized { + bytesReadMap.put(Thread.currentThread().getId, f()) + bytesReadMap.map { case (k, v) => + v - (if (k == baseline._1) baseline._2 else 0) + }.sum + } + } + } } /** @@ -353,6 +370,28 @@ class SparkHadoopUtil extends Logging { } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 77005aa9040b..c60a2a1706d5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy import java.io.{File, IOException} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.nio.file.Files import java.security.PrivilegedExceptionAction import java.text.ParseException @@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Properties import org.apache.commons.lang3.StringUtils -import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy import org.apache.ivy.core.LogOptions @@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils { RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose) } + // In client mode, download remote files. + if (deployMode == CLIENT) { + val hadoopConf = new HadoopConfiguration() + args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull + args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull + args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull + args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull + } + // Require all python files to be local, so we can add them to the PYTHONPATH // In YARN cluster mode, python files are distributed as regular files, which can be non-local. // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. @@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils { .mkString(",") if (merged == "") null else merged } + + /** + * Download a list of remote files to temp local files. If the file is local, the original file + * will be returned. + * @param fileList A comma separated file list. + * @return A comma separated local files list. 
+ */ + private[deploy] def downloadFileList( + fileList: String, + hadoopConf: HadoopConfiguration): String = { + require(fileList != null, "fileList cannot be null.") + fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",") + } + + /** + * Download a file from the remote to a local temporary directory. If the input path points to + * a local path, returns it with no operation. + */ + private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = { + require(path != null, "path cannot be null.") + val uri = Utils.resolveURI(path) + uri.getScheme match { + case "file" | "local" => + path + + case _ => + val fs = FileSystem.get(uri, hadoopConf) + val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath) + // scalastyle:off println + printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.") + // scalastyle:on println + fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath)) + Utils.resolveURI(tmpFile.getAbsolutePath).toString + } + } } /** Provides utility functions to be used inside SparkSubmit. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b5..6d8758a3d3b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9012736bc274..f4235df24512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. 
- !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0e7a6c24d4fa..af1471763340 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + // stripXSS is called first to remove suspicious characters used in XSS attacks val requestedIncomplete = - Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean + Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5..d9c8fda99ef9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fe..96b53c624232 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,7 +231,8 @@ private[deploy] class Master( logError("Leadership has 
been revoked -- master shutting down.") System.exit(0) - case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -243,7 +244,7 @@ private[deploy] class Master( workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress)) schedule() } else { val workerAddress = worker.endpoint.address diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 946a92882141..94ff81c1a68e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val appId = UIUtils.stripXSS(request.getParameter("appId")) val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) @@ -83,7 +84,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} -
<li><strong>Submit Date:</strong> {app.submitDate}</li>
+          <li><strong>Submit Date:</strong> {UIUtils.formatDate(app.submitDate)}</li>
           <li><strong>State:</strong> {app.state}</li>
  • { if (!app.isFinished) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index e722a24d4a89..ce71300e9097 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val id = Option(request.getParameter("id")) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val killFlag = + Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean + val id = Option(UIUtils.stripXSS(request.getParameter("id"))) if (id.isDefined && killFlag) { action(id.get) } @@ -252,7 +254,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => if (w.isAlive()) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e878c10183f6..58a181128eb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -57,7 +57,8 @@ private[deploy] class DriverRunner( @volatile private[worker] var finalException: Option[Exception] = None // Timeout to wait for when trying to terminate a driver. - private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000 + private val DRIVER_TERMINATE_TIMEOUT_MS = + conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s") // Decoupled for testing def setClock(_clock: Clock): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373d..ca9243e39c0a 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -99,6 +99,20 @@ private[deploy] class Worker( private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None + + /** + * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker + * will just use the address received from Master. + */ + private val preferConfiguredMasterAddress = + conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false) + /** + * The master address to connect in case of failure. When the connection is broken, worker will + * use this address to connect. This is usually just one of `masterRpcAddresses`. However, when + * a master is restarted or takes over leadership, it will be an address sent from master, which + * may not be in `masterRpcAddresses`. 
+ */ + private var masterAddressToConnect: Option[RpcAddress] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" private var workerWebUiUrl: String = "" @@ -196,10 +210,19 @@ private[deploy] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { + /** + * Change to use the new master. + * + * @param masterRef the new master ref + * @param uiUrl the new master Web UI address + * @param masterAddress the new master address which the worker should use to connect in case of + * failure + */ + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) { // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl + masterAddressToConnect = Some(masterAddress) master = Some(masterRef) connected = true if (conf.getBoolean("spark.ui.reverseProxy", false)) { @@ -266,7 +289,8 @@ private[deploy] class Worker( if (registerMasterFutures != null) { registerMasterFutures.foreach(_.cancel(true)) } - val masterAddress = masterRef.address + val masterAddress = + if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { @@ -342,15 +366,27 @@ private[deploy] class Worker( } private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) + masterEndpoint.send(RegisterWorker( + workerId, + host, + port, + self, + cores, + memory, + workerWebUiUrl, + masterEndpoint.address)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { msg match { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) => + if (preferConfiguredMasterAddress) { + logInfo("Successfully registered with master " + masterAddress.toSparkURL) + } else { + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + } registered = true - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterAddress) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { self.send(SendHeartbeat) @@ -419,7 +455,7 @@ private[deploy] class Worker( case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterRef.address) val execs = executors.values. 
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) @@ -561,7 +597,8 @@ private[deploy] class Worker( } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (master.exists(_.address == remoteAddress)) { + if (master.exists(_.address == remoteAddress) || + masterAddressToConnect.exists(_ == remoteAddress)) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 80dc9bf8779d..2f5a5642d3ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 + // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => @@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 83469c5ff060..d54dd2d46482 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -84,7 +86,20 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) // Pool used for threads that supervise task killing / cancellation private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") @@ -410,6 +425,7 @@ private[spark] class Executor( } } + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { @@ -432,7 +448,8 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case NonFatal(_) if task != null && task.reasonIfKilled.isDefined => + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acda..a3ce3d1ccc5e 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. 
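Stepping back to the Executor change above: the thread pool is built from Guava's ThreadFactoryBuilder, with a backing ThreadFactory deciding the concrete Thread subclass. A hedged, standalone sketch of the same pattern; WorkerThread is an illustrative stand-in for Spark's UninterruptibleThread:

import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import com.google.common.util.concurrent.ThreadFactoryBuilder

object PoolSketch {
  // Any Thread subclass works; Spark plugs in UninterruptibleThread here.
  class WorkerThread(target: Runnable, name: String) extends Thread(target, name)

  def newPool(): ThreadPoolExecutor = {
    val factory = new ThreadFactoryBuilder()
      .setDaemon(true)
      .setNameFormat("task-launch-worker-%d")   // the builder overrides the placeholder name
      .setThreadFactory(new ThreadFactory {
        override def newThread(r: Runnable): Thread = new WorkerThread(r, "unused")
      })
      .build()
    Executors.newCachedThreadPool(factory).asInstanceOf[ThreadPoolExecutor]
  }
}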
+ internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 89aeea493908..f8139b706a7c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -244,8 +244,8 @@ package object config { ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + "driver and executor environments contain sensitive information. When this regex matches " + - "a property, its value is redacted from the environment UI and various logs like YARN " + - "and event logs.") + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") .regexConf .createWithDefault("(?i)secret|password".r) @@ -272,4 +272,25 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockThreshold") + .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + + "record the size accurately if it's above this config. This helps to prevent OOM by " + + "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(100 * 1024 * 1024) + + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = + ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") + .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + "above this threshold. This is to avoid a giant request takes too much memory.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("200m") } diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 8f83668d7902..b3f8bfe8b1d4 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -46,5 +46,5 @@ trait BlockDataManager { /** * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. 
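For the three options added to the config package above, a hedged usage sketch; the values are arbitrary examples and could equally be passed as --conf flags to spark-submit:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("config-example")
  // Compress RDD checkpoint files using spark.io.compression.codec.
  .set("spark.checkpoint.compress", "true")
  // Record shuffle block sizes exactly once they exceed this threshold.
  .set("spark.shuffle.accurateBlockThreshold", "100m")
  // Fetch shuffle requests larger than this straight to disk instead of memory.
  .set("spark.reducer.maxReqSizeShuffleToMem", "200m")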
*/ - def releaseLock(blockId: BlockId): Unit + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index cb9d389dd7ea..6860214c7fe3 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network -import java.io.Closeable +import java.io.{Closeable, File} import java.nio.ByteBuffer import scala.concurrent.{Future, Promise} @@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo ret.flip() result.success(new NioManagedBuffer(ret)) } - }) + }, shuffleFiles = null) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df702..305fd9a6de10 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,11 +56,12 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val blocksNum = openBlocks.blockIds.length + val blocks = for (i <- (0 until blocksNum).view) + yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + logTrace(s"Registered streamId $streamId with $blocksNum buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. 
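On the OpenBlocks change above: wrapping the range in a .view makes each getBlockData call lazy, so buffers are only opened as the stream reaches them. A small illustrative sketch, with println standing in for the block lookup:

object LazyViewSketch {
  def main(args: Array[String]): Unit = {
    val blockIds = Array("shuffle_0_0_0", "shuffle_0_1_0", "shuffle_0_2_0")

    // Nothing is "opened" here; the function runs only when an element is consumed.
    val lazyBuffers = (0 until blockIds.length).view.map { i =>
      println(s"opening ${blockIds(i)}")
      blockIds(i).length          // stand-in for blockManager.getBlockData(...)
    }

    val it = lazyBuffers.iterator
    println(it.next())            // only the first block is opened at this point
  }
}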
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b75e91b66096..b13a9c681e54 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService( port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener, + transportConf, shuffleFiles).start() } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 4bf8ecc38354..76ea8b86c53d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -251,7 +251,13 @@ class HadoopRDD[K, V]( null } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener{ context => closeIfNeeded() } + context.addTaskCompletionListener { context => + // Update the bytes read before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. + updateBytesRead() + closeIfNeeded() + } + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ce3a9a2a1e2a..482875e6c1ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -191,7 +191,13 @@ class NewHadoopRDD[K, V]( } // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) + context.addTaskCompletionListener { context => + // Update the bytesRead before closing is to make sure lingering bytesRead statistics in + // this thread get correctly added. 
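The HadoopRDD/NewHadoopRDD hunks above register the flush in a task-completion listener so lingering thread-local counters are added before the reader closes. A hedged sketch of the same idea from user code, assuming it runs inside a task (TaskContext.get() returns null on the driver):

import org.apache.spark.TaskContext

// Wrap a partition iterator so that `flush` runs when the task finishes.
def withCleanup[T](iter: Iterator[T])(flush: () => Unit): Iterator[T] = {
  val ctx = TaskContext.get()
  if (ctx != null) {
    ctx.addTaskCompletionListener { _ =>
      flush()   // e.g. push lingering per-thread counters before streams are closed
    }
  }
  iter
}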
+ updateBytesRead() + close() + } + private var havePair = false private var recordsSinceMetricsUpdate = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1..63a87e7f09d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index e0a29b48314f..37c67cee55f9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = 
path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } else { + fileStream + } + } val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428..ab72addb2466 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aab177f257a8..35f6b365eca8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -618,12 +618,7 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`, - // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that - // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's - // safe to pass in null here. For more detail, see SPARK-13747. - val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - waiter.completionFuture.ready(Duration.Inf)(awaitPermission) + ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) waiter.completionFuture.value.get match { case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index aecb3a980e7c..a7dbf87915b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -252,11 +252,17 @@ private[spark] class EventLoggingListener( private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { - // "Spark Properties" entry will always exist because the map is always populated with it. 
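A simplified, hedged model of the rewritten redactEvent above: the redaction regex is applied to every property group, and a value is replaced whenever the key or the value matches. The regex and the replacement string below are illustrative:

import scala.util.matching.Regex

object RedactionSketch {
  val redactionRegex: Regex = "(?i)secret|password".r

  // Redact one group of key/value properties.
  def redact(props: Seq[(String, String)]): Seq[(String, String)] =
    props.map { case (k, v) =>
      if (redactionRegex.findFirstIn(k).isDefined || redactionRegex.findFirstIn(v).isDefined) {
        (k, "*********(redacted)")
      } else {
        (k, v)
      }
    }

  // Mirror of mapping over environmentDetails and redacting each section.
  def redactAll(details: Map[String, Seq[(String, String)]]): Map[String, Seq[(String, String)]] =
    details.map { case (name, props) => name -> redact(props) }
}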
- val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties")) - val redactedEnvironmentDetails = event.environmentDetails + - ("Spark Properties" -> redactedProps) - SparkListenerEnvironmentUpdate(redactedEnvironmentDetails) + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. + val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index b2e9a97129f0..048e0d018659 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.roaringbitmap.RoaringBitmap +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus( } /** - * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, + * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger + * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks, * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty - * @param avgSize average size of the non-empty blocks + * @param avgSize average size of the non-empty and non-huge blocks + * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
*/ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long) + private[this] var avgSize: Long, + @transient private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization - require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, + require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { 0 } else { - avgSize + hugeBlockSizes.get(reduceId) match { + case Some(size) => MapStatus.decompressSize(size) + case None => avgSize + } } } @@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private ( loc.writeExternal(out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) + out.writeInt(hugeBlockSizes.size) + hugeBlockSizes.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private ( emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() + val count = in.readInt() + val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until count).foreach { _ => + val block = in.readInt() + val size = in.readByte() + hugeBlockSizesArray += Tuple2(block, size) + } + hugeBlockSizes = hugeBlockSizesArray.toMap } } @@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus { // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() while (i < totalNumBlocks) { - var size = uncompressedSizes(i) + val size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 - totalSize += size + // Huge blocks are not included in the calculation for average size, thus size for smaller + // blocks is more accurate. 
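To make the comment above concrete, a small hedged calculation with made-up block sizes: blocks at or above the threshold keep an individually recorded size, everything else feeds the average, so one huge block no longer drags the average up or gets underestimated by it:

object BlockSizeSketch {
  def main(args: Array[String]): Unit = {
    val threshold = 100L * 1024 * 1024          // spark.shuffle.accurateBlockThreshold
    val sizes = Array(4L * 1024, 0L, 900L * 1024 * 1024, 8L * 1024)

    val nonEmpty = sizes.zipWithIndex.filter { case (size, _) => size > 0 }
    val (huge, small) = nonEmpty.partition { case (size, _) => size >= threshold }

    val avgSize = if (small.nonEmpty) small.map(_._1).sum / small.length else 0L
    println(s"average size used for small blocks: $avgSize bytes")   // ~6 KB
    huge.foreach { case (size, reduceId) =>
      println(s"block $reduceId recorded individually at $size bytes")
    }
  }
}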
+ if (size < threshold) { + totalSize += size + } else { + hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + } } else { emptyBlocks.add(i) } @@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7fd2918960cd..7767ef1803a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -115,26 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. - context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. + context.markTaskCompleted(None) } finally { - // Though we unset the ThreadLocal here, the context member variable itself is still queried - // directly in the TaskRunner to check for FetchFailedExceptions. - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. + TaskContext.unset() + } } } } @@ -182,14 +189,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. 
- // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. - } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4eedaaea6119..dc82bb770472 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -413,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -487,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -524,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -589,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. 
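The scheduler-backend hunk above and the one that follows keep an explicit requestedTotalExecutors and warn (via logDebug) when it drifts from the derivable count. A toy model of that bookkeeping, hedged and outside Spark:

final case class ExecutorCounts(
    requestedTotal: Int,      // explicit target sent to the cluster manager
    existing: Int,            // registered executors
    pending: Int,             // requested but not yet registered
    pendingToRemove: Int) {   // kill requests in flight

  // The invariant the new logDebug messages check for.
  def consistent: Boolean = requestedTotal == existing + pending - pendingToRemove

  def requestAdditional(n: Int): ExecutorCounts =
    copy(requestedTotal = requestedTotal + n, pending = pending + n)

  def kill(n: Int): ExecutorCounts =
    copy(requestedTotal = math.max(requestedTotal - n, 0), pendingToRemove = pendingToRemove + n)
}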
val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index ba3e0e395e95..2fbac79a2305 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator @@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66..f17b63775482 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". 
+ withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d159b9450ef5..56d8e51732ff 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -76,10 +76,13 @@ class ExecutorSummary private[spark]( val isBlacklisted: Boolean, val maxMemory: Long, val executorLogs: Map[String, String], - val onHeapMemoryUsed: Option[Long], - val offHeapMemoryUsed: Option[Long], - val maxOnHeapMemory: Option[Long], - val maxOffHeapMemory: Option[Long]) + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 3db59837fbeb..7064872ec1c7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging { /** * Release a lock on the given block. + * In case a TaskContext is not propagated properly to all child threads for the task, we fail to + * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock. + * + * See SPARK-18406 for more discussion of this issue. 
*/ - def unlock(blockId: BlockId): Unit = synchronized { - logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized { + val taskId = taskAttemptId.getOrElse(currentTaskAttemptId) + logTrace(s"Task $taskId releasing lock for $blockId") val info = get(blockId).getOrElse { throw new IllegalStateException(s"Block $blockId not found") } if (info.writerTask != BlockInfo.NO_WRITER) { info.writerTask = BlockInfo.NO_WRITER - writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + writeLocksByTask.removeBinding(taskId, blockId) } else { assert(info.readerCount > 0, s"Block $blockId is not locked for reading") info.readerCount -= 1 - val countsForTask = readLocksByTask(currentTaskAttemptId) + val countsForTask = readLocksByTask(taskId) val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 assert(newPinCountForTask >= 0, - s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + s"Task $taskId release lock on block $blockId more times than it acquired it") } notifyAll() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06..5f067191070e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,14 +23,12 @@ import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import com.google.common.io.ByteStreams - import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging @@ -41,7 +39,6 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -337,7 +334,7 @@ private[spark] class BlockManager( val task = asyncReregisterTask if (task != null) { try { - Await.ready(task, Duration.Inf) + ThreadUtils.awaitReady(task, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for async. 
reregistration", t) @@ -504,6 +501,7 @@ private[spark] class BlockManager( case Some(info) => val level = info.level logDebug(s"Level for block $blockId is $level") + val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()) if (level.useMemory && memoryStore.contains(blockId)) { val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get @@ -511,7 +509,12 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream( blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag) } - val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + // We need to capture the current taskId in case the iterator completion is triggered + // from a different thread which does not have TaskContext set; see SPARK-18406 for + // discussion. + val ci = CompletionIterator[Any, Iterator[Any]](iter, { + releaseLock(blockId, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { val diskData = diskStore.getBytes(blockId) @@ -528,8 +531,9 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, - releaseLockAndDispose(blockId, diskData)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, { + releaseLockAndDispose(blockId, diskData, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -707,10 +711,13 @@ private[spark] class BlockManager( } /** - * Release a lock on the given block. + * Release a lock on the given block with explicit TID. + * The param `taskAttemptId` should be passed in case we can't get the correct TID from + * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child + * thread. */ - def releaseLock(blockId: BlockId): Unit = { - blockInfoManager.unlock(blockId) + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = { + blockInfoManager.unlock(blockId, taskAttemptId) } /** @@ -912,7 +919,7 @@ private[spark] class BlockManager( if (level.replication > 1) { // Wait for asynchronous replication to finish try { - Await.ready(replicationFuture, Duration.Inf) + ThreadUtils.awaitReady(replicationFuture, Duration.Inf) } catch { case NonFatal(t) => throw new Exception("Error occurred while waiting for replication to finish", t) @@ -1463,8 +1470,11 @@ private[spark] class BlockManager( } } - def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { - blockInfoManager.unlock(blockId) + def releaseLockAndDispose( + blockId: BlockId, + data: BlockData, + taskAttemptId: Option[Long] = None): Unit = { + releaseLock(blockId, taskAttemptId) data.dispose() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 467c3e0e6b51..6f85b9e4d6c7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -497,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already. 
val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -520,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. 
- val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f8906117638b..bded3a1e4eb5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{File, InputStream, IOException} import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy @@ -52,6 +52,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. */ private[spark] @@ -63,6 +64,7 @@ final class ShuffleBlockFetcherIterator( streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, + maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { @@ -129,6 +131,12 @@ final class ShuffleBlockFetcherIterator( @GuardedBy("this") private[this] var isZombie = false + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + val shuffleFilesSet = mutable.HashSet[File]() + initialize() // Decrements the buffer reference count. 
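For context on the new maxReqSizeShuffleToMem constructor parameter above, a hypothetical call site would thread the threshold in from configuration. This is a sketch only: the real call site is not part of these hunks, the leading constructor arguments are assumed, and the configuration key and default shown are assumptions rather than something these hunks establish.

    // Sketch of wiring maxReqSizeShuffleToMem into the fetcher (assumed key, assumed arguments).
    val maxReqSizeShuffleToMem =
      SparkEnv.get.conf.getLong("spark.reducer.maxReqSizeShuffleToMem", Long.MaxValue)
    val fetcher = new ShuffleBlockFetcherIterator(
      context, shuffleClient, blockManager, blocksByAddress, streamWrapper,
      maxBytesInFlight, maxReqsInFlight, maxReqSizeShuffleToMem, detectCorrupt)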
@@ -163,6 +171,11 @@ final class ShuffleBlockFetcherIterator( case _ => } } + shuffleFilesSet.foreach { file => + if (!file.delete()) { + logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath()); + } + } } private[this] def sendRequest(req: FetchRequest) { @@ -175,33 +188,46 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) - val address = req.address - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { - // Only add the buffer to results queue if the iterator is not zombie, - // i.e. cleanup() has not been called yet. - ShuffleBlockFetcherIterator.this.synchronized { - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - remainingBlocks -= blockId - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) - logDebug("remainingBlocks: " + remainingBlocks) - } + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) } - logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } - override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), address, e)) - } + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } - ) + } + + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is + // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch + // the data and write it to file directly. 
+ if (req.size > maxReqSizeShuffleToMem) { + val shuffleFiles = blockIds.map { _ => + blockManager.diskBlockManager.createTempLocalBlock()._2 + }.toArray + shuffleFilesSet ++= shuffleFiles + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, shuffleFiles) + } else { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 1b30d4fa93bc..ac60f795915a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 8f0d181fc8fe..e9694fdbca2d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. 
*/ @DeveloperApi +@deprecated("This class may be removed or made private in a future release.", "2.2.0") class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index bdbdba578085..edf328b5ae53 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 7d31ac54a717..bf4cf79e9faa 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -117,7 +117,7 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", + sparkUser = getSparkUser, completed = false )) )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e53d6907bc40..4bc7fb6185e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = @@ -446,7 +450,7 @@ private[spark] object UIUtils extends Logging { val xml = XML.loadString(s"""$desc""") // Verify that this has only anchors and span (we are wrapping in span) - val allowedNodeLabels = Set("a", "span") + val allowedNodeLabels = Set("a", "span", "br") val illegalNodes = xml \\ "_" filterNot { case node: Node => allowedNodeLabels.contains(node.label) } @@ -527,4 +531,21 @@ private[spark] object UIUtils extends Logging { origHref } } + + /** + * Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks + * + * For more information about XSS testing: + * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and + * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001) + */ + def stripXSS(requestParameter: String): String = { + if (requestParameter == null) { + null + } else { + // Remove new lines and single quotes, followed by escaping HTML version 4.0 + StringEscapeUtils.escapeHtml4( + NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 70b3ffd95e60..8c18464e6477 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en * A SparkListener that prepares information to be displayed on the EnvironmentTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 6ce3f511e89c..7b211ea5199c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage private val sc = parent.sc + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val executorId = Option(request.getParameter("executorId")).map { executorId => + val executorId = + Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId => UIUtils.decodeURLParameter(executorId) }.getOrElse { throw new IllegalArgumentException(s"Missing executorId parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 0a3c63d14ca8..b7cbed468517 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive @@ -114,10 +114,16 @@ 
private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem - val onHeapMemUsed = status.onHeapMemUsed - val offHeapMemUsed = status.offHeapMemUsed - val maxOnHeapMem = status.maxOnHeapMem - val maxOffHeapMem = status.maxOffHeapMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -142,10 +148,7 @@ private[spark] object ExecutorsPage { taskSummary.isBlacklisted, maxMem, taskSummary.executorLogs, - onHeapMemUsed, - offHeapMemUsed, - maxOnHeapMem, - maxOffHeapMem + memoryMetrics ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 03851293eb2f..aabf6e0c63c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -62,6 +62,7 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 18be0870746e..a0fd29c22ddc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { jobTag: String, jobs: Seq[JobUIData], killEnabled: Boolean): Seq[Node] = { - val allParameters = request.getParameterMap.asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" - val parameterJobPage = request.getParameter(jobTag + ".page") - val parameterJobSortColumn = request.getParameter(jobTag + ".sort") - val parameterJobSortDesc = request.getParameter(jobTag + ".desc") - val parameterJobPageSize = request.getParameter(jobTag + ".pageSize") - val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page")) + val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort")) + val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc")) + val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize")) + val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize")) val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1) 
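To make the sanitization in these pages concrete, here is a rough sketch of what UIUtils.stripXSS (added earlier in this patch) does to a hostile query parameter; the sample value is illustrative.

    // Newlines (raw or URL-encoded) and single quotes are removed first, then the remainder
    // is HTML 4.0-escaped, so markup can no longer break out of the surrounding page.
    val raw  = "1%0D%0A<script>alert('x')</script>"
    val safe = UIUtils.stripXSS(raw)
    // safe == "1&lt;script&gt;alert(x)&lt;/script&gt;"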
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 3131c4a1eb7d..9fb011a049b7 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val listener = parent.jobProgresslistener listener.synchronized { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val jobId = parameterId.toInt diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f78db5ab80d1..7370f9feb68c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._ * updating the internal data structures concurrently. */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Define a handful of type aliases so that data structures' types can serve as documentation. @@ -328,13 +329,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val metrics = TaskMetrics.empty val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics))) + stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo)) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -404,7 +404,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } - val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info, None)) + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info)) taskData.updateTaskInfo(info) taskData.updateTaskMetrics(taskMetrics) taskData.errorMessage = errorMessage diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 620c54c2dc0a..cc173381879a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all jobs in the given SparkContext. 
*/ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val jobId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) jobId.foreach { id => if (jobProgresslistener.activeJobs.contains(id)) { sc.foreach(_.cancelJob(id)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 8ee70d27cc09..b164f32b62e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val poolName = Option(request.getParameter("poolname")).map { poolname => + // stripXSS is called first to remove suspicious characters used in XSS attacks + val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => UIUtils.decodeURLParameter(poolname) }.getOrElse { throw new IllegalArgumentException(s"Missing poolname parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 19325a2dc916..6b3dadc33331 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterAttempt = request.getParameter("attempt") + val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - val parameterTaskPage = request.getParameter("task.page") - val parameterTaskSortColumn = request.getParameter("task.sort") - val parameterTaskSortDesc = request.getParameter("task.desc") - val parameterTaskPageSize = request.getParameter("task.pageSize") - val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") + val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) + val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) + val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) + val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 
256b726fa7ee..a28daf7f9045 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -42,15 +42,17 @@ private[ui] class StageTableBase( isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { - val allParameters = request.getParameterMap().asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) - val parameterStagePage = request.getParameter(stageTag + ".page") - val parameterStageSortColumn = request.getParameter(stageTag + ".sort") - val parameterStageSortDesc = request.getParameter(stageTag + ".desc") - val parameterStagePageSize = request.getParameter(stageTag + ".pageSize") - val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize") + val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page")) + val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort")) + val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc")) + val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize")) + val parameterStagePrevPageSize = + UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize")) val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => @@ -512,4 +514,3 @@ private[ui] class StageDataSource( } } } - diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 181465bdf960..799d76962639 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. 
*/ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { @@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val stageId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { sc.foreach(_.cancelStage(id, "killed via the Web UI")) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index ac1a74ad8029..8bedd071a2c1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -20,6 +20,8 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} @@ -112,9 +114,9 @@ private[spark] object UIData { /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. */ - class TaskUIData private( - private var _taskInfo: TaskInfo, - private var _metrics: Option[TaskMetricsUIData]) { + class TaskUIData private(private var _taskInfo: TaskInfo) { + + private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY) var errorMessage: Option[String] = None @@ -127,7 +129,7 @@ private[spark] object UIData { } def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = { - _metrics = TaskUIData.toTaskMetricsUIData(metrics) + _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics) } def taskDuration: Option[Long] = { @@ -140,28 +142,16 @@ private[spark] object UIData { } object TaskUIData { - def apply(taskInfo: TaskInfo, metrics: Option[TaskMetrics]): TaskUIData = { - new TaskUIData(dropInternalAndSQLAccumulables(taskInfo), toTaskMetricsUIData(metrics)) + + private val stringInterner = Interners.newWeakInterner[String]() + + /** String interning to reduce the memory usage. 
*/ + private def weakIntern(s: String): String = { + stringInterner.intern(s) } - private def toTaskMetricsUIData(metrics: Option[TaskMetrics]): Option[TaskMetricsUIData] = { - metrics.map { m => - TaskMetricsUIData( - executorDeserializeTime = m.executorDeserializeTime, - executorDeserializeCpuTime = m.executorDeserializeCpuTime, - executorRunTime = m.executorRunTime, - executorCpuTime = m.executorCpuTime, - resultSize = m.resultSize, - jvmGCTime = m.jvmGCTime, - resultSerializationTime = m.resultSerializationTime, - memoryBytesSpilled = m.memoryBytesSpilled, - diskBytesSpilled = m.diskBytesSpilled, - peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics), - outputMetrics = OutputMetricsUIData(m.outputMetrics), - shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), - shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) - } + def apply(taskInfo: TaskInfo): TaskUIData = { + new TaskUIData(dropInternalAndSQLAccumulables(taskInfo)) } /** @@ -174,8 +164,8 @@ private[spark] object UIData { index = taskInfo.index, attemptNumber = taskInfo.attemptNumber, launchTime = taskInfo.launchTime, - executorId = taskInfo.executorId, - host = taskInfo.host, + executorId = weakIntern(taskInfo.executorId), + host = weakIntern(taskInfo.host), taskLocality = taskInfo.taskLocality, speculative = taskInfo.speculative ) @@ -206,6 +196,28 @@ private[spark] object UIData { shuffleReadMetrics: ShuffleReadMetricsUIData, shuffleWriteMetrics: ShuffleWriteMetricsUIData) + object TaskMetricsUIData { + def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = { + TaskMetricsUIData( + executorDeserializeTime = m.executorDeserializeTime, + executorDeserializeCpuTime = m.executorDeserializeCpuTime, + executorRunTime = m.executorRunTime, + executorCpuTime = m.executorCpuTime, + resultSize = m.resultSize, + jvmGCTime = m.jvmGCTime, + resultSerializationTime = m.resultSerializationTime, + memoryBytesSpilled = m.memoryBytesSpilled, + diskBytesSpilled = m.diskBytesSpilled, + peakExecutionMemory = m.peakExecutionMemory, + inputMetrics = InputMetricsUIData(m.inputMetrics), + outputMetrics = OutputMetricsUIData(m.outputMetrics), + shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), + shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) + } + + val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty) + } + case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) object InputMetricsUIData { def apply(metrics: InputMetrics): InputMetricsUIData = { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index a1a0c729b924..317e0aa5ea25 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterBlockPage = request.getParameter("block.page") - val parameterBlockSortColumn = request.getParameter("block.sort") - val parameterBlockSortDesc = request.getParameter("block.desc") - val 
parameterBlockPageSize = request.getParameter("block.pageSize") - val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") + val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page")) + val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort")) + val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc")) + val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize")) + val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize")) val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be..148efb134e14 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 7479de55140e..603c23abb689 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -68,7 +68,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { private def assertMetadataNotNull(): Unit = { if (metadata == null) { - throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.") + throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.") } } @@ -85,7 +85,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { */ final def name: Option[String] = { assertMetadataNotNull() - metadata.name + + if (atDriverSide) { + metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name)) + } else { + metadata.name + } } /** @@ -161,7 +166,17 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX) + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + // For non-internal accumulators, we still need to send the name because users may need to + // access the accumulator name at executor side, or they may keep the accumulators sent from + // executors and access the name when the registered accumulator is already garbage + // collected(e.g. SQLMetrics). + copyAcc.metadata = metadata + } copyAcc } else { this @@ -250,7 +265,7 @@ private[spark] object AccumulatorContext { // Since we are storing weak references, we must check whether the underlying data is valid. 
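The weak-reference theme above also shows up in the UIData change earlier in this patch, which uses Guava's weak interner to deduplicate the executorId and host strings repeated across many tasks. A minimal standalone sketch of that pattern (the sample value is illustrative):

    import com.google.common.collect.Interners

    // Equal strings share one canonical instance; entries can be garbage-collected once
    // nothing outside the interner references them any more.
    val stringInterner = Interners.newWeakInterner[String]()
    def weakIntern(s: String): String = stringInterner.intern(s)

    // Thousands of TaskInfo rows naming the same host collapse to a single String object.
    val host = weakIntern("executor-host-01.example.com")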
val acc = ref.get if (acc eq null) { - throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id") + throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") } acc } @@ -263,16 +278,6 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4..ce06e18879a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 1aa4456ed01b..81aaf79db0c1 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -206,4 +206,25 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitresult + + // scalastyle:off awaitready + /** + * Preferred alternative to `Await.ready()`. + * + * @see [[awaitResult]] + */ + @throws(classOf[SparkException]) + def awaitReady[T](awaitable: Awaitable[T], atMost: Duration): awaitable.type = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.ready(atMost)(awaitPermission) + } catch { + // TimeoutException is thrown in the current thread, so not need to warp the exception. 
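Call sites elsewhere in this patch (for example in BlockManager) switch from Await.ready to this helper; a minimal before/after sketch, with `task` standing in for any Awaitable:

    import scala.concurrent.duration.Duration

    // Before: Await.ready(task, Duration.Inf)  // discouraged by the awaitready style rule above
    // After: the helper added here, which wraps non-timeout failures in a SparkException.
    ThreadUtils.awaitReady(task, Duration.Inf)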
+ case NonFatal(t) if !t.isInstanceOf[TimeoutException] => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } + // scalastyle:on awaitready } diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e2..27922b31949b 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. */ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 943dde072327..67497bbba150 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,7 +22,7 @@ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInf import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.Channels +import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import java.util.{Locale, Properties, Random, UUID} @@ -60,7 +60,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.util.logging.RollingFileAppender /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -319,41 +318,22 @@ private[spark] object Utils extends Logging { * copying is disabled by default unless explicitly set transferToEnabled as true, * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false, - transferToEnabled: Boolean = false): Long = - { - var count = 0L + def copyStream( + in: InputStream, + out: OutputStream, + closeStreams: Boolean = false, + transferToEnabled: Boolean = false): Long = { tryWithSafeFinally { if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. val inChannel = in.asInstanceOf[FileInputStream].getChannel() val outChannel = out.asInstanceOf[FileOutputStream].getChannel() - val initialPos = outChannel.position() val size = inChannel.size() - - // In case transferTo method transferred less data than we have required. - while (count < size) { - count += inChannel.transferTo(count, size - count, outChannel) - } - - // Check the position after transferTo loop to see if it is in the right position and - // give user information if not. 
- // Position will not be increased to the expected length after calling transferTo in - // kernel version 2.6.32, this issue can be seen in - // https://bugs.openjdk.java.net/browse/JDK-7052359 - // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = outChannel.position() - assert(finalPos == initialPos + size, - s""" - |Current position $finalPos do not equal to expected position ${initialPos + size} - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + copyFileStreamNIO(inChannel, outChannel, 0, size) + size } else { + var count = 0L val buf = new Array[Byte](8192) var n = 0 while (n != -1) { @@ -363,8 +343,8 @@ private[spark] object Utils extends Logging { count += n } } + count } - count } { if (closeStreams) { try { @@ -376,6 +356,37 @@ private[spark] object Utils extends Logging { } } + def copyFileStreamNIO( + input: FileChannel, + output: FileChannel, + startPosition: Long, + bytesToCopy: Long): Unit = { + val initialPos = output.position() + var count = 0L + // In case transferTo method transferred less data than we have required. + while (count < bytesToCopy) { + count += input.transferTo(count + startPosition, bytesToCopy - count, output) + } + assert(count == bytesToCopy, + s"request to copy $bytesToCopy bytes, but actually copied $count bytes.") + + // Check the position after transferTo loop to see if it is in the right position and + // give user information if not. + // Position will not be increased to the expected length after calling transferTo in + // kernel version 2.6.32, this issue can be seen in + // https://bugs.openjdk.java.net/browse/JDK-7052359 + // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). + val finalPos = output.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } + /** * Construct a URI container information used for authentication. * This also sets the default authenticator to properly negotiation the @@ -740,7 +751,11 @@ private[spark] object Utils extends Logging { * always return a single directory. 
*/ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { @@ -2606,10 +2621,24 @@ private[spark] object Utils extends Logging { } private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { - kvs.map { kv => - redactionPattern.findFirstIn(kv._1) - .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } - .getOrElse(kv) + // If the sensitive information regex matches with either the key or the value, redact the value + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accounting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, user would have to make the spark.redaction.regex property + // more specific. + kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) } } diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e88ab68..51feccfb8342 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -55,14 +55,16 @@ class TaskCompletionListenerException( extends RuntimeException { override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + - previousError.map { e => + val listenerErrorMessage = + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + val previousErrorMessage = previousError.map { e => "\n\nPrevious exception in task: " + e.getMessage + "\n" + e.getStackTrace.mkString("\t", "\n\t", "") }.getOrElse("") + listenerErrorMessage + previousErrorMessage } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - 
"onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, 
"maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index ddbcb2d19dcb..3990ee1ec326 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -210,7 +210,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(ref.get.isEmpty) // Getting a garbage collected accum should throw error - intercept[IllegalAccessError] { + intercept[IllegalStateException] { AccumulatorContext.get(accId) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46..ee70a3399efe 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = 
Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index bb24c6ce4d33..71bedda5ac89 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{any, isA} +import org.mockito.Matchers.any import org.mockito.Mockito._ import org.apache.spark.broadcast.BroadcastManager diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 735f4454e299..979270a527a6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.concurrent.duration._ -import scala.concurrent.Await import com.google.common.io.Files import org.apache.hadoop.conf.Configuration @@ -35,7 +34,7 @@ import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -301,13 +300,13 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) sc.addJar(tmpJar.getAbsolutePath) - // Invaid jar path will only print the error log, will not add to file server. + // Invalid jar path will only print the error log, will not add to file server. 
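The new CheckpointCompressionSuite above exercises the spark.checkpoint.compress flag end to end. For orientation, a minimal application-side sketch of the same setting (not part of the patch; the master URL and checkpoint directory are placeholders):

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCompressionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")                        // placeholder master
      .setAppName("checkpoint-compression-sketch")
      .set("spark.checkpoint.compress", "true")     // the flag the suite above verifies
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoint-sketch")   // placeholder directory
    val rdd = sc.makeRDD(1 to 20, numSlices = 1)
    rdd.checkpoint()
    rdd.count()                                     // first action triggers the checkpoint write
    assert(rdd.collect().toSeq == (1 to 20))        // later reads go through the (compressed) files
    sc.stop()
  }
}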
sc.addJar("dummy.jar") sc.addJar("") sc.addJar(tmpDir.getAbsolutePath) - sc.listJars().size should be (1) - sc.listJars().head should include (tmpJar.getName) + assert(sc.listJars().size == 1) + assert(sc.listJars().head.contains(tmpJar.getName)) } test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { @@ -315,7 +314,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) sc.cancelJobGroup("nonExistGroupId") - Await.ready(future, Duration(2, TimeUnit.SECONDS)) + ThreadUtils.awaitReady(future, Duration(2, TimeUnit.SECONDS)) // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause // SparkContext to shutdown, so the following assertion will fail. @@ -540,10 +539,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } - // Launches one task that will run forever. Once the SparkListener detects the task has + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has // started, kill and re-schedule it. The second run of the task will complete immediately. // If this test times out, then the first version of the task wasn't killed successfully. - test("Killing tasks") { + def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) SparkContextSuite.isTaskStarted = false @@ -572,13 +585,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu // first attempt will hang if (!SparkContextSuite.isTaskStarted) { SparkContextSuite.isTaskStarted = true - try { - Thread.sleep(9999999) - } catch { - case t: Throwable => - // SPARK-20217 should not fail stage if task throws non-interrupted exception - throw new RuntimeException("killed") - } + blockFn } // second attempt succeeds immediately } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index f50cb38311db..42b8cde65039 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -243,16 +243,22 @@ private[deploy] object IvyTestUtils { withManifest: Option[Manifest] = None): File = { val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) - val manifest = withManifest.getOrElse { - val mani = new Manifest() + val manifest: Manifest = withManifest.getOrElse { if (withR) { + val mani = new Manifest() val attr = mani.getMainAttributes attr.put(Name.MANIFEST_VERSION, "1.0") attr.put(new Name("Spark-HasRPackage"), "true") + mani + } else { + null } - mani } - val jarStream = new JarOutputStream(jarFileStream, manifest) + val jarStream = if (manifest != null) { + new JarOutputStream(jarFileStream, manifest) + } else { + new JarOutputStream(jarFileStream) + } for (file <- files) { val jarEntry = new JarEntry(file._1) diff --git 
a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 005587051b6a..5e0bf6d438dc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -133,6 +133,16 @@ class RPackageUtilsSuite } } + test("jars without manifest return false") { + IvyTestUtils.withRepository(main, None, None) { repo => + val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil, + useIvyLayout = false, withR = false, None) + val jarFile = new JarFile(jar) + assert(jarFile.getManifest == null, "jar file should have null manifest") + assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false") + } + } + test("SparkR zipping works properly") { val tempDir = Files.createTempDir() Utils.tryWithSafeFinally { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 000000000000..ab24a76e20a3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c2ec01a03d0..6e9721c45931 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -18,11 +18,16 @@ package org.apache.spark.deploy import java.io._ +import java.net.URI import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import 
org.apache.commons.io.{FilenameUtils, FileUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,6 +39,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -404,6 +410,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) @@ -501,7 +538,7 @@ class SparkSubmitSuite test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars - val files = "hdfs:/file1,file2" // --files + val files = "local:/file1,file2" // --files val archives = "file:/archive1,archive2" // --archives val pyFiles = "py-file1,py-file2" // --py-files @@ -553,7 +590,7 @@ class SparkSubmitSuite test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars - val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files + val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files val archives = "file:/archive1,archive2" // spark.yarn.dist.archives val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles @@ -671,6 +708,87 @@ class SparkSubmitSuite } // scalastyle:on println + private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = { + if (sourcePath == outputPath) { + return + } + + val sourceUri = new URI(sourcePath) + val outputUri = new URI(outputPath) + assert(outputUri.getScheme === "file") + + // The path and filename are preserved. 
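The redaction end-to-end test above depends on the key-or-value matching added to Utils.redact earlier in this patch. A standalone sketch of that rule (illustrative only; the regex below is a placeholder, not Spark's default spark.redaction.regex value):

import scala.util.matching.Regex

object RedactionRuleSketch {
  private val ReplacementText = "*********(redacted)"   // stands in for REDACTION_REPLACEMENT_TEXT

  def redact(pattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] =
    kvs.map { case (key, value) =>
      pattern.findFirstIn(key)
        .orElse(pattern.findFirstIn(value))             // the value is checked too, per this patch
        .map(_ => (key, ReplacementText))
        .getOrElse((key, value))
    }

  def main(args: Array[String]): Unit = {
    val pattern = "(?i)secret|password|token".r          // placeholder pattern
    val props = Seq(
      "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> "secret_password",
      "sun.java.command" -> "... HADOOP_CREDSTORE_PASSWORD=secret_password ...",
      "spark.app.name" -> "testApp")
    redact(pattern, props).foreach(println)              // only spark.app.name keeps its value
  }
}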
+ assert(outputUri.getPath.endsWith(sourceUri.getPath)) + assert(FileUtils.readFileToString(new File(outputUri.getPath)) === + FileUtils.readFileToString(new File(sourceUri.getPath))) + } + + private def deleteTempOutputFile(outputPath: String): Unit = { + val outputFile = new File(new URI(outputPath).getPath) + if (outputFile.exists) { + outputFile.delete() + } + } + + test("downloadFile - invalid url") { + intercept[IOException] { + SparkSubmit.downloadFile("abc:/my/file", new Configuration()) + } + } + + test("downloadFile - file doesn't exist") { + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + intercept[FileNotFoundException] { + SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf) + } + } + + test("downloadFile does not download local file") { + // empty path is considered as local file. + assert(SparkSubmit.downloadFile("", new Configuration()) === "") + assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file") + } + + test("download one file to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePath = s"s3a://${jarFile.getAbsolutePath}" + val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + + test("download list of files to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") + val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + + assert(outputPaths.length === sourcePaths.length) + sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -773,3 +891,10 @@ object UserClasspathFirstTest { } } } + +class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { + override def copyToLocalFile(src: Path, dst: Path): Unit = { + // Ignore the scheme for testing. 
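The downloadFile tests above avoid a real object store by mapping the s3a:// scheme onto a local FileSystem implementation. A small sketch of that Hadoop configuration trick in isolation (the path is a placeholder):

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

object FakeS3aSchemeSketch {
  def main(args: Array[String]): Unit = {
    val hadoopConf = new Configuration()
    // Point the s3a scheme at the test-only LocalFileSystem subclass added by this patch.
    hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
    // Bypass the FileSystem cache so the override takes effect immediately.
    hadoopConf.set("fs.s3a.impl.disable.cache", "true")

    val fs = FileSystem.get(new URI("s3a:///tmp/placeholder"), hadoopConf)
    println(fs.getClass.getName)   // expected: org.apache.spark.deploy.TestFileSystem
  }
}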
+ super.copyToLocalFile(new Path(src.toUri.getPath), dst) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 9839dcf8535d..bf7480d79f8a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice (SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ec580a44b8e7..456158d41b93 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. 
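The updated standalone dynamic-allocation tests above now ask for a bounded number of executors via sc.requestExecutors and assert that the application's executor limit becomes 2 rather than Int.MaxValue. A sketch of the same call from an application (the master URL is a placeholder; the API is only meaningful on coarse-grained cluster backends such as standalone mode):

import org.apache.spark.{SparkConf, SparkContext}

object RequestExecutorsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf()
      .setMaster("spark://master:7077")       // placeholder standalone master
      .setAppName("request-executors-sketch"))
    // Ask the cluster manager for two executors; in the updated tests this drives the
    // application's executor limit to 2.
    val acknowledged = sc.requestExecutors(2)
    println(s"request acknowledged: $acknowledged")
    sc.stop()
  }
}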
assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 764156c3edc4..95acb9a54440 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 2127da48ece4..539264652d7d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -34,7 +34,7 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv} class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -447,8 +447,15 @@ class MasterSuite extends SparkFunSuite } }) - master.self.send( - RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost", 9999))) val executors = (0 until 3).map { i => new 
ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) } @@ -459,4 +466,37 @@ class MasterSuite extends SparkFunSuite assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) } } + + test("SPARK-20529: Master should reply the address received from worker") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + + @volatile var receivedMasterAddress: RpcAddress = null + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(_, _, masterAddress) => + receivedMasterAddress = masterAddress + } + }) + + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000))) + + eventually(timeout(10.seconds)) { + assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f47e574b4fc4..efcad140350b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. 
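The new executor-thread test above reads the task thread's class from inside a running task. The same probe can be run from any application; a sketch assuming a local master:

import org.apache.spark.{SparkConf, SparkContext}

object ExecutorThreadProbe {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("thread-probe"))
    val threadClass = sc.parallelize(Seq(1), 1)
      .map(_ => Thread.currentThread().getClass.getName)   // evaluated on the executor side
      .collect()
      .head
    println(threadClass)   // expected: org.apache.spark.util.UninterruptibleThread
    sc.stop()
  }
}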
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 5d522189a0c2..6f4203da1d86 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { @@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext } assert(bytesRead >= tmpFile.length()) } + + test("input metrics with old Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new Hadoop API in different thread") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).mapPartitions { iter => + val buf = new ArrayBuffer[String]() + ThreadUtils.runInNewThread("testThread", false) { + iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_)) + } + + buf.iterator + }.count() + } + assert(bytesRead >= tmpFile.length()) + } } /** diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index fe8955840d72..474e30144f62 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -22,7 +22,7 @@ import java.nio._ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit -import scala.concurrent.{Await, Promise} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -36,6 +36,7 @@ import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, ShuffleBlockId} +import org.apache.spark.util.ThreadUtils class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { @@ -164,9 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }) + }, null) - Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) + ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ad56715656c8..8d06f5468f4f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat} import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} class RDDSuite extends SparkFunSuite with SharedSparkContext { var tempDir: File = _ @@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(totalPartitionCount == 10) } + test("SPARK-18406: race between end-of-task and completion iterator read lock release") { + val rdd = sc.parallelize(1 to 1000, 10) + rdd.cache() + + rdd.mapPartitions { iter => + ThreadUtils.runInNewThread("TestThread") { + // Iterate to the end of the input iterator, to cause the CompletionIterator completion to + // fire outside of the task's main thread. + while (iter.hasNext) { + iter.next() + } + iter + } + }.collect() + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a..7f20206202cb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 759d52fca5ce..3ec37f674c77 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import scala.util.Random +import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.BlockManagerId @@ -128,4 +132,26 @@ class MapStatusSuite extends SparkFunSuite { assert(size1 === size2) assert(!success) } + + test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " + + "underestimated.") { + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000") + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + // Value of element in sizes is equal to the corresponding index. 
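Several new tests in this patch (the input-metrics tests and the SPARK-18406 regression test above) run part of the task body on a separate thread via ThreadUtils.runInNewThread. A minimal standalone sketch of that helper's call shape, as used here:

import org.apache.spark.util.ThreadUtils

object RunInNewThreadSketch {
  def main(args: Array[String]): Unit = {
    // Runs the block on a newly created (non-daemon) thread; the result, or any exception,
    // is handed back to the calling thread.
    val sum = ThreadUtils.runInNewThread("sketch-thread", false) {
      (1 to 10).sum
    }
    println(sum)   // 55
  }
}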
+ val sizes = (0L to 2000L).toArray + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val arrayStream = new ByteArrayOutputStream(102400) + val objectOutputStream = new ObjectOutputStream(arrayStream) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + objectOutputStream.writeObject(status1) + objectOutputStream.flush() + val array = arrayStream.toByteArray + val objectInput = new ObjectInputStream(new ByteArrayInputStream(array)) + val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus] + (1001 to 2000).foreach { + case part => assert(status2.getSizeForBlock(part) >= sizes(part)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 8300607ea888..37b08980db87 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{TimeoutException, TimeUnit} import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration.{Duration, SECONDS} import scala.language.existentials import scala.reflect.ClassTag @@ -260,7 +260,7 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa */ def awaitJobTermination(jobFuture: Future[_], duration: Duration): Unit = { try { - Await.ready(jobFuture, duration) + ThreadUtils.awaitReady(jobFuture, duration) } catch { case te: TimeoutException if backendException.get() != null => val msg = raw""" diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8f576daa77d1..992d3396d203 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark context.addTaskCompletionListener(_ => throw new Exception("blah")) intercept[TaskCompletionListenerException] { - context.markTaskCompleted() + context.markTaskCompleted(None) } verify(listener, times(1)).onTaskCompletion(any()) @@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. 
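The MapStatus test above verifies that block sizes survive Java serialization without being underestimated. A generic round-trip helper in the same spirit (illustrative sketch only; any Serializable value works):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object JavaSerializationRoundTrip {
  def roundTrip[T](value: T): T = {
    val buffer = new ByteArrayOutputStream(4096)
    val out = new ObjectOutputStream(buffer)
    out.writeObject(value)          // value must be Serializable
    out.flush()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    println(roundTrip(Vector(1, 2, 3)))   // Vector(1, 2, 3)
  }
}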
- val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), @@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("immediately call a completion listener if the context is completed") { var invocations = 0 val context = TaskContext.empty() - context.markTaskCompleted() + context.markTaskCompleted(None) context.addTaskCompletionListener(_ => invocations += 1) assert(invocations == 1) - context.markTaskCompleted() + context.markTaskCompleted(None) assert(invocations == 1) } @@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(lastError == error) assert(invocations == 1) } + + test("TaskCompletionListenerException.getMessage should include previousError") { + val listenerErrorMessage = "exception in listener" + val taskErrorMessage = "exception in task" + val e = new TaskCompletionListenerException( + Seq(listenerErrorMessage), + Some(new RuntimeException(taskErrorMessage))) + assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage)) + } + + test("all TaskCompletionListeners should be called even if some fail or a task") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener1")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskCompleted(Some(new Exception("exception in task"))) + } + + // Make sure listener 2 was called. 
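The assertions that follow depend on how the refactored TaskCompletionListenerException.getMessage (see the taskListeners.scala hunk earlier in this patch) concatenates listener failures with the task's own error. A standalone sketch of that composition, omitting the stack-trace portion the real implementation appends:

object ListenerMessageSketch {
  def compose(errorMessages: Seq[String], previousError: Option[Throwable]): String = {
    val listenerErrorMessage =
      if (errorMessages.size == 1) {
        errorMessages.head
      } else {
        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
      }
    val previousErrorMessage = previousError.map { e =>
      "\n\nPrevious exception in task: " + e.getMessage
    }.getOrElse("")
    listenerErrorMessage + previousErrorMessage   // both parts end up in getMessage
  }

  def main(args: Array[String]): Unit = {
    println(compose(
      Seq("exception in listener1", "exception in listener3"),
      Some(new RuntimeException("exception in task"))))
  }
}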
+ verify(listener, times(1)).onTaskCompletion(any()) + + // also need to check failure in TaskCompletionListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe63..db14c9acfdce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1070,11 +1070,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) - when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - assert(manager.isZombie === true) - } - }) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala index 1bfb0c1547ec..82bd7c4ff660 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -31,7 +31,7 @@ class AllStagesResourceSuite extends SparkFunSuite { val tasks = new LinkedHashMap[Long, TaskUIData] taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => tasks(idx.toLong) = TaskUIData( - new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None) + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false)) } val stageUiData = new StageUIData() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 1b325801e27f..917db766f7f1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -152,7 +152,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { // one should acquire the write lock. The second thread should block until the winner of the // write race releases its lock. 
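Throughout these suites the patch replaces scala.concurrent.Await.ready with ThreadUtils.awaitReady, Spark's in-tree wrapper for blocking on a future. A minimal usage sketch (the computation is arbitrary):

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._

import org.apache.spark.util.ThreadUtils

object AwaitReadySketch {
  def main(args: Array[String]): Unit = {
    val future = Future { 21 * 2 }
    ThreadUtils.awaitReady(future, 10.seconds)   // blocks until the future completes or times out
    println(future.value)                        // Some(Success(42))
  }
}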
val winningFuture: Future[Boolean] = - Await.ready(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds) + ThreadUtils.awaitReady(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds) assert(winningFuture.value.get.get) val winningTID = blockInfoManager.get("block").get.writerTask assert(winningTID === 1 || winningTID === 2) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a8b960489983..9d7a8696818f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.File import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -1265,7 +1266,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], - listener: BlockFetchingListener): Unit = { + listener: BlockFetchingListener, + shuffleFiles: Array[File]): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index dfecd04c1b96..4000218e71a8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bbfd6df3b699..7859b0bba2b4 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import scala.language.reflectiveCalls - import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd..6883eb211efd 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,22 +33,66 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + + private def assumeNonExistentAndNotCreatable(f: File): Unit = { + try { + assume(!f.exists() && !f.mkdirs()) + } finally { + Utils.deleteRecursively(f) + } + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + val f = new File("/NONEXISTENT_PATH") + assumeNonExistentAndNotCreatable(f) + val conf = new SparkConf(false) .set("spark.local.dir", 
s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) + + // This directory should not be created. + assert(!f.exists()) } test("SPARK_LOCAL_DIRS override also affects driver") { - // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + // Regression test for SPARK-2974 + val f = new File("/NONEXISTENT_PATH") + assumeNonExistentAndNotCreatable(f) + // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) .set("spark.local.dir", "/NONEXISTENT_PATH") assert(new File(Utils.getLocalDir(conf)).exists()) + + // This directory should not be created. + assert(!f.exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + val f1 = new File(path1) + val f2 = new File(path2) + assumeNonExistentAndNotCreatable(f1) + assumeNonExistentAndNotCreatable(f2) + + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. + assert(message.contains(s"$path1,$path2")) + + // These directories should not be created. + assert(!f1.exists()) + assert(!f2.exists()) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala index 3050f9a25023..535105379963 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite try { TaskContext.setTaskContext(TaskContext.empty()) val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() Mockito.verifyNoMoreInteractions(memoryStore) } finally { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e56e440380a5..559b3faab8fd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, InputStream, IOException} +import java.util.UUID import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -35,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with 
PrivateMethodTester { @@ -44,7 +46,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -106,6 +109,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) // 3 local blocks fetched in initialization @@ -134,7 +138,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -153,7 +157,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -181,6 +186,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -192,7 +198,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() - taskContext.markTaskCompleted() + taskContext.markTaskCompleted(None) verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() // The 3rd block should not be retained because the iterator is already in zombie state @@ -218,7 +224,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -246,6 +253,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -281,7 +289,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new 
FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -309,6 +318,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -318,7 +328,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -359,7 +370,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] Future { @@ -387,6 +399,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => new LimitedInputStream(in, 100), 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, false) // Continue only after the mock calls onBlockFetchFailure @@ -401,4 +414,65 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(id3 === ShuffleBlockId(0, 2, 0)) } + test("Blocks should be shuffled to disk when size of the request is above the" + + " threshold(maxReqSizeShuffleToMem).") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + val tmpDir = Utils.createTempDir() + doReturn{ + val blockId = TempLocalBlockId(UUID.randomUUID()) + (blockId, new File(tmpDir, blockId.name)) + }.when(diskBlockManager).createTempLocalBlock() + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } + }) + + def 
fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks + // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. + new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + maxBytesInFlight = Int.MaxValue, + maxReqsInFlight = Int.MaxValue, + maxReqSizeShuffleToMem = 200, + detectCorrupt = true) + } + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + fetchShuffleBlock(blocksByAddress1) + // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch + // shuffle block to disk. + assert(shuffleFiles === null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + fetchShuffleBlock(blocksByAddress2) + // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch + // shuffle block to disk. + assert(shuffleFiles != null) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index c770fd5da76f..423daacc0f5a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite { assert(decoded2 === decodeURLParameter(decoded2)) } + test("SPARK-20393: Prevent newline characters in parameters.") { + val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b" + val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b" + + assert(stripEncoding === stripXSS(encoding)) + } + + test("SPARK-20393: Prevent script from parameters running on page.") { + val scriptAlert = """>"'>