diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE
index 0e41cf182645..5af45d6fa798 100644
--- a/.github/PULL_REQUEST_TEMPLATE
+++ b/.github/PULL_REQUEST_TEMPLATE
@@ -7,4 +7,4 @@
 (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
 (If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
 
-Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request.
+Please review http://spark.apache.org/contributing.html before opening a pull request.
diff --git a/.gitignore b/.gitignore
index 39d17e1793f7..5634a434db0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,8 @@ project/plugins/project/build.properties
 project/plugins/src_managed/
 project/plugins/target/
 python/lib/pyspark.zip
+python/deps
+python/pyspark/python
 reports/
 scalastyle-on-compile.generated.xml
 scalastyle-output.xml
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1a8206abe383..8fdd5aa9e7df 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,12 +1,12 @@
 ## Contributing to Spark
 
 *Before opening a pull request*, review the
-[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark).
+[Contributing to Spark guide](http://spark.apache.org/contributing.html).
 It lists steps that are required before creating a PR. In particular, consider:
 
 - Is the change important and ready enough to ask the community to spend time reviewing?
 - Have you searched for existing, related JIRAs and pull requests?
-- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ?
+- Is this a new feature that can stand alone as a [third party project](http://spark.apache.org/third-party-projects.html)?
 - Is the change being proposed clearly explained and motivated?
 
 When you contribute code, you affirm that the contribution is your original work and that you
diff --git a/LICENSE b/LICENSE
index 7950dd6ceb6d..119ecbe15ab7 100644
--- a/LICENSE
+++ b/LICENSE
@@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala)

-     (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/)
-     (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/)
+     (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/)
+     (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/)
      (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org)
      (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org)
      (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org)
diff --git a/NOTICE b/NOTICE
index 69b513ea3ba3..f4b64b5c3f47 100644
--- a/NOTICE
+++ b/NOTICE
@@ -421,9 +421,6 @@ Copyright (c) 2011, Terrence Parr.
 This product includes/uses ASM (http://asm.ow2.org/),
 Copyright (c) 2000-2007 INRIA, France Telecom.
 
-This product includes/uses org.json (http://www.json.org/java/index.html),
-Copyright (c) 2002 JSON.org
-
 This product includes/uses JLine (http://jline.sourceforge.net/),
 Copyright (c) 2002-2006, Marc Prud'hommeaux .
 
diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md
new file mode 100644
index 000000000000..d6084c7a7cc9
--- /dev/null
+++ b/R/CRAN_RELEASE.md
@@ -0,0 +1,91 @@
+# SparkR CRAN Release
+
+To release SparkR as a package to CRAN, we use the `devtools` package. Please work with the
+`dev@spark.apache.org` community and the R package maintainer on this.
+
+### Release
+
+First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control.
+
+Note that while `run-tests.sh` runs `check-cran.sh` (which runs `R CMD check`), it does so with `--no-manual --no-vignettes`, which skips a few vignette and PDF checks - it is therefore preferable to run `R CMD check` on a manually built source package before uploading a release. Also note that for the CRAN PDF vignette checks to succeed, the `qpdf` tool must be installed (e.g. `yum -q -y install qpdf`).
+
+To upload a release, we need to update `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script, along with comments on the status of any `WARNING` (there should not be any) or `NOTE` entries. As part of `check-cran.sh` and the release process, the vignettes are built - make sure `SPARK_HOME` is set and the Spark jars are accessible.
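+
+For example, a manual check might look like the following sketch (the tarball name comes from the
+`Version:` field in `pkg/DESCRIPTION`; adjust it to the version being released):
+
+```sh
+# from the SPARK_HOME/R directory, with SPARK_HOME set and the Spark jars built
+R CMD build pkg
+R CMD check --as-cran SparkR_2.1.2.tar.gz
+```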
+
+Once everything is in place, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths)
+```
+
+For more information, please refer to http://r-pkgs.had.co.nz/release.html#release-check
+
+### Testing: build package manually
+
+To build the package manually, for example to inspect the resulting `.tar.gz` file content, we also use the `devtools` package.
+
+The source package is what gets released to CRAN. CRAN then builds platform-specific binary packages from the source package.
+
+#### Build source package
+
+To build the source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths)
+```
+
+(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2)
+
+Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`.
+
+For example, this should be the content of the source package:
+
+```sh
+DESCRIPTION     R       inst    tests
+NAMESPACE       build   man     vignettes
+
+inst/doc/
+sparkr-vignettes.html
+sparkr-vignettes.Rmd
+sparkr-vignettes.Rman
+
+build/
+vignette.rds
+
+man/
+ *.Rd files...
+
+vignettes/
+sparkr-vignettes.Rmd
+```
+
+#### Test source package
+
+To install it, run:
+
+```sh
+R CMD INSTALL SparkR_2.1.0.tar.gz
+```
+
+with "2.1.0" replaced by the version of SparkR being released.
+
+This command installs SparkR to the default libPaths. Once that is done, you should be able to start R and run:
+
+```R
+library(SparkR)
+vignette("sparkr-vignettes", package="SparkR")
+```
+
+#### Build binary package
+
+To build the binary package locally, run in R under the `SPARK_HOME/R` directory:
+
+```R
+paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths)
+```
+
+For example, this should be the content of the binary package:
+
+```sh
+DESCRIPTION     Meta    R       html    tests
+INDEX           NAMESPACE       help    profile worker
+```
diff --git a/R/README.md b/R/README.md
index 932d5272d0b4..4c40c5963db7 100644
--- a/R/README.md
+++ b/R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R
 Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system-wide installation of R. However, this can be changed to any user-installed location of R by setting the environment variable `R_HOME` to the full path of the base directory where R is installed, before running the install-dev.sh script.
 
-Example: 
+Example:
 ```bash
 # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript
 export R_HOME=/home/username/R
@@ -46,19 +46,19 @@
 Sys.setenv(SPARK_HOME="/Users/username/spark")
 # This line loads SparkR from the installed directory
 .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths()))
 library(SparkR)
-sc <- sparkR.init(master="local")
+sparkR.session()
 ```
 
 #### Making changes to SparkR
 
-The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR.
+The [instructions](http://spark.apache.org/contributing.html) for making contributions to Spark also apply to SparkR.
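+
+For R-only changes, the quick iteration loop (described in more detail below) might look like this
+sketch, run from the Spark root directory:
+
+```bash
+# re-install the SparkR package, then run the SparkR unit tests
+./R/install-dev.sh
+./R/run-tests.sh
+```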
 If you only make R file changes (i.e. no Scala changes), then you can just re-install the R package using `R/install-dev.sh` and test your changes.
 Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below.
- 
+
 #### Generating documentation
 
 The SparkR documentation (Rd files and HTML files) is not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md`
- 
+
 ### Examples, Unit tests
 
 SparkR comes with several sample programs in the `examples/src/main/r` directory.
diff --git a/R/check-cran.sh b/R/check-cran.sh
index bb331466ae93..1288e7fc9fb4 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -34,13 +34,30 @@ if [ ! -z "$R_HOME" ]
   fi
   R_SCRIPT_PATH="$(dirname $(which R))"
 fi
-echo "USING R_HOME = $R_HOME"
+echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}"
 
-# Build the latest docs
+# Install the package (this is required for code in vignettes to run when building it later)
+# Build the latest docs, but not vignettes, which are built with the package next
 $FWDIR/create-docs.sh
 
-# Build a zip file containing the source package
-"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg
+# Build the source package with vignettes
+SPARK_HOME="$(cd "${FWDIR}"/..; pwd)"
+. "${SPARK_HOME}"/bin/load-spark-env.sh
+if [ -f "${SPARK_HOME}/RELEASE" ]; then
+  SPARK_JARS_DIR="${SPARK_HOME}/jars"
+else
+  SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars"
+fi
+
+if [ -d "$SPARK_JARS_DIR" ]; then
+  # Build a zip file containing the source package with vignettes
+  SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg
+
+  find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete
+else
+  echo "Error: Spark JARs not found in $SPARK_HOME"
+  exit 1
+fi
 
 # Run check as-cran.
 VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'`
@@ -54,11 +71,32 @@ fi
 
 if [ -n "$NO_MANUAL" ]
 then
-  CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual"
+  CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes"
 fi
 
 echo "Running CRAN check with $CRAN_CHECK_OPTIONS options"
 
-"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz
+if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ]
+then
+  "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz
+else
+  # This will run tests and/or build vignettes, and requires SPARK_HOME
+  SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz
+fi
+
+# Install the source package so that it generates the vignettes rds files, etc.
+if [ -n "$CLEAN_INSTALL" ] +then + echo "Removing lib path and installing from source package" + LIB_DIR="$FWDIR/lib" + rm -rf $LIB_DIR + mkdir -p $LIB_DIR + "$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR + + # Zip the SparkR package so that it can be distributed to worker nodes on YARN + pushd $LIB_DIR > /dev/null + jar cfM "$LIB_DIR/sparkr.zip" SparkR + popd > /dev/null +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c3..84e6aa928cb0 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/install-dev.sh b/R/install-dev.sh index ada6303a722b..0f881208bcad 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -46,7 +46,7 @@ if [ ! -z "$R_HOME" ] fi R_SCRIPT_PATH="$(dirname $(which R))" fi -echo "USING R_HOME = $R_HOME" +echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" # Generate Rd files if devtools is installed "$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index 544d203a6dce..f12f8c275a98 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -1,5 +1,8 @@ ^.*\.Rproj$ ^\.Rproj\.user$ ^\.lintr$ +^cran-comments\.md$ +^NEWS\.md$ +^README\.Rmd$ ^src-native$ ^html$ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0..2d461ca68920 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package +Version: 2.1.2 Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Description: The SparkR package provides an R Frontend for Apache Spark. 
Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -10,17 +10,18 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), person("Felix", "Cheung", role = "aut", email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ -BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports +BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods Suggests: + knitr, + rmarkdown, testthat, e1071, survival -Description: The SparkR package provides an R frontend for Apache Spark. -License: Apache License (== 2.0) Collate: 'schema.R' 'generics.R' @@ -48,3 +49,4 @@ Collate: 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9cd6269f9a8f..6f96b969d149 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,9 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
 # Imports from base R
 # Do not include stats:: "rpois", "runif" - causes error at runtime
 importFrom("methods", "setGeneric", "setMethod", "setOldClass")
 importFrom("methods", "is", "new", "signature", "show")
 importFrom("stats", "gaussian", "setNames")
-importFrom("utils", "download.file", "object.size", "packageVersion", "untar")
+importFrom("utils", "download.file", "object.size", "packageVersion", "tail", "untar")
 
 # Disable native libraries till we figure out how to package it
 # See SPARKR-7839
@@ -16,6 +33,7 @@ export("sparkR.stop")
 export("sparkR.session.stop")
 export("sparkR.conf")
 export("sparkR.version")
+export("sparkR.uiWebUrl")
 
 export("print.jobj")
 
 export("sparkR.newJObject")
@@ -45,7 +63,8 @@ exportMethods("glm",
               "spark.als",
               "spark.kstest",
               "spark.logit",
-              "spark.randomForest")
+              "spark.randomForest",
+              "spark.gbt")
 
 # Job group lifecycle management methods
 export("setJobGroup",
@@ -61,6 +80,7 @@ exportMethods("arrange",
               "as.data.frame",
               "attach",
               "cache",
+              "coalesce",
              "collect",
              "colnames",
              "colnames<-",
@@ -92,6 +112,7 @@ exportMethods("arrange",
              "freqItems",
              "gapply",
              "gapplyCollect",
+              "getNumPartitions",
              "group_by",
              "groupBy",
              "head",
@@ -353,7 +374,9 @@ export("as.DataFrame",
       "read.ml",
       "print.summary.KSTest",
       "print.summary.RandomForestRegressionModel",
-       "print.summary.RandomForestClassificationModel")
+       "print.summary.RandomForestClassificationModel",
+       "print.summary.GBTRegressionModel",
+       "print.summary.GBTClassificationModel")
 
 export("structField",
        "structField.jobj",
@@ -380,6 +403,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
 S3method(print, summary.RandomForestRegressionModel)
 S3method(print, summary.RandomForestClassificationModel)
+S3method(print, summary.GBTRegressionModel)
+S3method(print, summary.GBTClassificationModel)
 S3method(structField, character)
 S3method(structField, jobj)
 S3method(structType, jobj)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 1cf9b38ea648..d0f097925a2c 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -417,7 +417,7 @@ setMethod("coltypes",
                 type <- PRIMITIVE_TYPES[[specialtype]]
               }
             }
-            type
+            type[[1]]
           })
 
 # Find which types don't have mapping to R
@@ -680,14 +680,53 @@ setMethod("storageLevel",
           storageLevelToString(callJMethod(x@sdf, "storageLevel"))
         })
 
+#' Coalesce
+#'
+#' Returns a new SparkDataFrame that has exactly \code{numPartitions} partitions.
+#' This operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100
+#' partitions, there will not be a shuffle; instead, each of the 100 new partitions will claim 10 of
+#' the current partitions. If a larger number of partitions is requested, it will stay at the
+#' current number of partitions.
+#'
+#' However, if you're doing a drastic coalesce on a SparkDataFrame, e.g. to numPartitions = 1,
+#' this may result in your computation taking place on fewer nodes than
+#' you like (e.g. one node in the case of numPartitions = 1). To avoid this,
+#' call \code{repartition}. This will add a shuffle step, but means the
+#' current upstream partitions will be executed in parallel (per whatever
+#' the current partitioning is).
+#'
+#' @param numPartitions the number of partitions to use.
+#'
+#' @family SparkDataFrame functions
+#' @rdname coalesce
+#' @name coalesce
+#' @aliases coalesce,SparkDataFrame-method
+#' @seealso \link{repartition}
+#' @export
+#' @examples
+#'\dontrun{
+#' sparkR.session()
+#' path <- "path/to/file.json"
+#' df <- read.json(path)
+#' newDF <- coalesce(df, 1L)
+#'}
+#' @note coalesce(SparkDataFrame) since 2.1.1
+setMethod("coalesce",
+          signature(x = "SparkDataFrame"),
+          function(x, numPartitions) {
+            stopifnot(is.numeric(numPartitions))
+            sdf <- callJMethod(x@sdf, "coalesce", numToInt(numPartitions))
+            dataFrame(sdf)
+          })
+
 #' Repartition
 #'
 #' The following options for repartition are possible:
 #' \itemize{
-#'  \item{1.} {Return a new SparkDataFrame partitioned by
+#'  \item{1.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.}
+#'  \item{2.} {Return a new SparkDataFrame hash partitioned by
 #'             the given columns into \code{numPartitions}.}
-#'  \item{2.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.}
-#'  \item{3.} {Return a new SparkDataFrame partitioned by the given column(s),
+#'  \item{3.} {Return a new SparkDataFrame hash partitioned by the given column(s),
 #'             using \code{spark.sql.shuffle.partitions} as number of partitions.}
 #'}
 #' @param x a SparkDataFrame.
@@ -699,6 +738,7 @@ setMethod("storageLevel",
 #' @rdname repartition
 #' @name repartition
 #' @aliases repartition,SparkDataFrame-method
+#' @seealso \link{coalesce}
 #' @export
 #' @examples
 #'\dontrun{
@@ -937,6 +977,8 @@ setMethod("unique",
 #' Sample
 #'
 #' Return a sampled subset of this SparkDataFrame using a random seed.
+#' Note: this is not guaranteed to provide exactly the fraction specified
+#' of the total count of the given SparkDataFrame.
 #'
 #' @param x A SparkDataFrame
 #' @param withReplacement Sampling with replacement or not
@@ -1130,6 +1172,7 @@ setMethod("collect",
           if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") {
             vec <- do.call(c, col)
             stopifnot(class(vec) != "list")
+            class(vec) <- PRIMITIVE_TYPES[[colType]]
             df[[colIndex]] <- vec
           } else {
             df[[colIndex]] <- col
@@ -1709,6 +1752,23 @@ getColumn <- function(x, c) {
   column(callJMethod(x@sdf, "col", c))
 }
 
+setColumn <- function(x, c, value) {
+  if (class(value) != "Column" && !is.null(value)) {
+    if (isAtomicLengthOne(value)) {
+      value <- lit(value)
+    } else {
+      stop("value must be a Column, a literal value (atomic vector of length 1), or NULL")
+    }
+  }
+
+  if (is.null(value)) {
+    nx <- drop(x, c)
+  } else {
+    nx <- withColumn(x, c, value)
+  }
+  nx
+}
+
 #' @param name name of a Column (without being wrapped by \code{""}).
 #' @rdname select
 #' @name $
@@ -1719,20 +1779,15 @@ setMethod("$", signature(x = "SparkDataFrame"),
             getColumn(x, name)
           })
 
-#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped.
+#' @param value a Column, an atomic vector of length 1 as a literal value, or \code{NULL}.
+#' If \code{NULL}, the specified Column is dropped.
 #' @rdname select
 #' @name $<-
 #' @aliases $<-,SparkDataFrame-method
 #' @note $<- since 1.4.0
 setMethod("$<-", signature(x = "SparkDataFrame"),
           function(x, name, value) {
-            stopifnot(class(value) == "Column" || is.null(value))
-
-            if (is.null(value)) {
-              nx <- drop(x, name)
-            } else {
-              nx <- withColumn(x, name, value)
-            }
+            nx <- setColumn(x, name, value)
             x@sdf <- nx@sdf
             x
           })
@@ -1745,6 +1800,10 @@ setClassUnion("numericOrcharacter", c("numeric", "character"))
 #' @note [[ since 1.4.0
 setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"),
           function(x, i) {
+            if (length(i) > 1) {
+              warning("Subset index has length > 1. Only the first index is used.")
+              i <- i[1]
+            }
             if (is.numeric(i)) {
               cols <- columns(x)
               i <- cols[[i]]
@@ -1752,6 +1811,25 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"),
             getColumn(x, i)
           })
 
+#' @rdname subset
+#' @name [[<-
+#' @aliases [[<-,SparkDataFrame,numericOrcharacter-method
+#' @note [[<- since 2.1.1
+setMethod("[[<-", signature(x = "SparkDataFrame", i = "numericOrcharacter"),
+          function(x, i, value) {
+            if (length(i) > 1) {
+              warning("Subset index has length > 1. Only the first index is used.")
+              i <- i[1]
+            }
+            if (is.numeric(i)) {
+              cols <- columns(x)
+              i <- cols[[i]]
+            }
+            nx <- setColumn(x, i, value)
+            x@sdf <- nx@sdf
+            x
+          })
+
 #' @rdname subset
 #' @name [
 #' @aliases [,SparkDataFrame-method
@@ -1796,14 +1874,19 @@ setMethod("[", signature(x = "SparkDataFrame"),
 #' Return subsets of SparkDataFrame according to given conditions
 #' @param x a SparkDataFrame.
 #' @param i,subset (Optional) a logical expression to filter on rows.
+#' For the extract operator [[ and replacement operator [[<-, the index of a single Column.
 #' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame.
 #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column.
 #' Otherwise, a SparkDataFrame will always be returned.
+#' @param value a Column, an atomic vector of length 1 as a literal value, or \code{NULL}.
+#' If \code{NULL}, the specified Column is dropped.
 #' @param ... currently not used.
 #' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns.
 #' @export
 #' @family SparkDataFrame functions
 #' @aliases subset,SparkDataFrame-method
+#' @seealso \link{withColumn}
 #' @rdname subset
 #' @name subset
 #' @family subsetting functions
@@ -1821,6 +1904,10 @@ setMethod("[", signature(x = "SparkDataFrame"),
 #' subset(df, df$age %in% c(19, 30), 1:2)
 #' subset(df, df$age %in% c(19), select = c(1,2))
 #' subset(df, select = c(1,2))
+#' # Columns can be selected and set
+#' df[["age"]] <- 23
+#' df[[1]] <- df$age
+#' df[[2]] <- NULL # drop column
 #' }
 #' @note subset since 1.5.0
 setMethod("subset", signature(x = "SparkDataFrame"),
@@ -1939,13 +2026,13 @@ setMethod("selectExpr",
 #'
 #' @param x a SparkDataFrame.
 #' @param colName a column name.
-#' @param col a Column expression.
+#' @param col a Column expression, or an atomic vector of length 1 as a literal value.
 #' @return A SparkDataFrame with the new column added or the existing column replaced.
 #' @family SparkDataFrame functions
-#' @aliases withColumn,SparkDataFrame,character,Column-method
+#' @aliases withColumn,SparkDataFrame,character-method
 #' @rdname withColumn
 #' @name withColumn
-#' @seealso \link{rename} \link{mutate}
+#' @seealso \link{rename} \link{mutate} \link{subset}
 #' @export
 #' @examples
 #'\dontrun{
@@ -1955,11 +2042,20 @@ setMethod("selectExpr",
 #' newDF <- withColumn(df, "newCol", df$col1 * 5)
 #' # Replace an existing column
 #' newDF2 <- withColumn(newDF, "newCol", newDF$col1)
+#' newDF3 <- withColumn(newDF, "newCol", 42)
+#' # Use extract operator to set an existing or new column
+#' df[["age"]] <- 23
+#' df[[2]] <- df$col1
+#' df[[2]] <- NULL # drop column
 #' }
 #' @note withColumn since 1.4.0
 setMethod("withColumn",
-          signature(x = "SparkDataFrame", colName = "character", col = "Column"),
+          signature(x = "SparkDataFrame", colName = "character"),
           function(x, colName, col) {
+            if (class(col) != "Column") {
+              if (!isAtomicLengthOne(col)) stop("Literal value must be an atomic vector of length 1")
+              col <- lit(col)
+            }
             sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc)
             dataFrame(sdf)
           })
@@ -2305,9 +2401,9 @@ setMethod("dropDuplicates",
 #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a
 #' Column expression. If joinExpr is omitted, the default, inner join is attempted and an error is
 #' thrown if it would be a Cartesian Product. For Cartesian join, use crossJoin instead.
-#' @param joinType The type of join to perform. The following join types are available:
-#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left',
-#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner".
+#' @param joinType The type of join to perform, default 'inner'.
+#' Must be one of: 'inner', 'cross', 'outer', 'full', 'full_outer',
+#' 'left', 'left_outer', 'right', 'right_outer', 'left_semi', or 'left_anti'.
 #' @return A SparkDataFrame containing the result of the join operation.
 #' @family SparkDataFrame functions
 #' @aliases join,SparkDataFrame,SparkDataFrame-method
@@ -2336,15 +2432,18 @@ setMethod("join",
             if (is.null(joinType)) {
               sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc)
             } else {
-              if (joinType %in% c("inner", "outer", "full", "fullouter",
-                                  "leftouter", "left_outer", "left",
-                                  "rightouter", "right_outer", "right", "leftsemi")) {
+              if (joinType %in% c("inner", "cross",
+                                  "outer", "full", "fullouter", "full_outer",
+                                  "left", "leftouter", "left_outer",
+                                  "right", "rightouter", "right_outer",
+                                  "left_semi", "leftsemi", "left_anti", "leftanti")) {
                 joinType <- gsub("_", "", joinType)
                 sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType)
               } else {
                 stop("joinType must be one of the following types: ",
-                     "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left',
-                     'rightouter', 'right_outer', 'right', 'leftsemi'")
+                     "'inner', 'cross', 'outer', 'full', 'full_outer', ",
+                     "'left', 'left_outer', 'right', 'right_outer', ",
+                     "'left_semi', or 'left_anti'.")
               }
             }
           }
@@ -2539,7 +2638,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
 #'
 #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame
 #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL.
-#' Note that this does not remove duplicate rows across the two SparkDataFrames.
+#'
+#' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
 #' @param x A SparkDataFrame
 #' @param y A SparkDataFrame
@@ -2582,7 +2682,8 @@ setMethod("unionAll",
 #' Union two or more SparkDataFrames
 #'
 #' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL.
-#' Note that this does not remove duplicate rows across the two SparkDataFrames.
+#'
+#' Note: This does not remove duplicate rows across the two SparkDataFrames.
 #'
 #' @param x a SparkDataFrame.
 #' @param ... additional SparkDataFrame(s).
@@ -3381,3 +3482,26 @@ setMethod("randomSplit",
             }
             sapply(sdfs, dataFrame)
           })
+
+#' getNumPartitions
+#'
+#' Return the number of partitions
+#'
+#' @param x A SparkDataFrame
+#' @family SparkDataFrame functions
+#' @aliases getNumPartitions,SparkDataFrame-method
+#' @rdname getNumPartitions
+#' @name getNumPartitions
+#' @export
+#' @examples
+#'\dontrun{
+#' sparkR.session()
+#' df <- createDataFrame(cars, numPartitions = 2)
+#' getNumPartitions(df)
+#' }
+#' @note getNumPartitions since 2.1.1
+setMethod("getNumPartitions",
+          signature(x = "SparkDataFrame"),
+          function(x) {
+            callJMethod(callJMethod(x@sdf, "rdd"), "getNumPartitions")
+          })
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 0f1162fec1df..5667b9d78882 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -313,7 +313,7 @@ setMethod("checkpoint",
 #' @rdname getNumPartitions
 #' @aliases getNumPartitions,RDD-method
 #' @noRd
-setMethod("getNumPartitions",
+setMethod("getNumPartitionsRDD",
           signature(x = "RDD"),
           function(x) {
             callJMethod(getJRDD(x), "getNumPartitions")
@@ -329,7 +329,7 @@ setMethod("numPartitions",
           signature(x = "RDD"),
           function(x) {
             .Deprecated("getNumPartitions")
-            getNumPartitions(x)
+            getNumPartitionsRDD(x)
           })
 
 #' Collect elements of an RDD
@@ -460,7 +460,7 @@ setMethod("countByValue",
           signature(x = "RDD"),
           function(x) {
             ones <- lapply(x, function(item) { list(item, 1L) })
-            collectRDD(reduceByKey(ones, `+`, getNumPartitions(x)))
+            collectRDD(reduceByKey(ones, `+`, getNumPartitionsRDD(x)))
           })
 
 #' Apply a function to all elements
@@ -780,7 +780,7 @@ setMethod("takeRDD",
             resList <- list()
             index <- -1
             jrdd <- getJRDD(x)
-            numPartitions <- getNumPartitions(x)
+            numPartitions <- getNumPartitionsRDD(x)
             serializedModeRDD <- getSerializedMode(x)
 
             # TODO(shivaram): Collect more than one partition based on size
@@ -846,7 +846,7 @@ setMethod("firstRDD",
 #' @noRd
 setMethod("distinctRDD",
           signature(x = "RDD"),
-          function(x, numPartitions = SparkR:::getNumPartitions(x)) {
+          function(x, numPartitions = SparkR:::getNumPartitionsRDD(x)) {
             identical.mapped <- lapply(x, function(x) { list(x, NULL) })
             reduced <- reduceByKey(identical.mapped,
                                    function(x, y) { x },
@@ -1028,7 +1028,7 @@ setMethod("repartitionRDD",
           signature(x = "RDD"),
           function(x, numPartitions) {
             if (!is.null(numPartitions) && is.numeric(numPartitions)) {
-              coalesce(x, numPartitions, TRUE)
+              coalesceRDD(x, numPartitions, TRUE)
             } else {
               stop("Please, specify the number of partitions")
             }
@@ -1049,11 +1049,11 @@ setMethod("repartitionRDD",
 #' @rdname coalesce
 #' @aliases coalesce,RDD
 #' @noRd
-setMethod("coalesce",
+setMethod("coalesceRDD",
           signature(x = "RDD", numPartitions = "numeric"),
           function(x, numPartitions, shuffle = FALSE) {
             numPartitions <- numToInt(numPartitions)
-            if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) {
+            if (shuffle || numPartitions > SparkR:::getNumPartitionsRDD(x)) {
               func <- function(partIndex, part) {
                 set.seed(partIndex)  # partIndex as seed
                 start <- as.integer(base::sample(numPartitions, 1) - 1)
@@ -1143,7 +1143,7 @@ setMethod("saveAsTextFile",
 #' @noRd
 setMethod("sortBy",
signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitionsRDD(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1175,7 +1175,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- getNumPartitions(newRdd) + numPartitions <- getNumPartitionsRDD(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1407,7 +1407,7 @@ setMethod("setName", setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- getNumPartitions(x) + n <- getNumPartitionsRDD(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1450,7 +1450,7 @@ setMethod("zipWithUniqueId", setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- getNumPartitions(x) + n <- getNumPartitionsRDD(x) if (n > 1) { nums <- collectRDD(lapplyPartition(x, function(part) { @@ -1566,8 +1566,8 @@ setMethod("unionRDD", setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- getNumPartitions(x) - n2 <- getNumPartitions(other) + n1 <- getNumPartitionsRDD(x) + n2 <- getNumPartitionsRDD(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1637,7 +1637,7 @@ setMethod("cartesian", #' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) @@ -1671,7 +1671,7 @@ setMethod("subtract", #' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1714,7 +1714,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, getNumPartitions) + nPart <- sapply(rrdds, getNumPartitionsRDD) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 38d83c6e5c52..e771a057e244 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -184,8 +184,11 @@ getDefaultSqlSource <- function() { #' #' Converts R data.frame or list into SparkDataFrame. #' -#' @param data an RDD or list or data.frame. +#' @param data a list or data.frame. #' @param schema a list of column names or named list (StructType), optional. +#' @param samplingRatio Currently not used. +#' @param numPartitions the number of partitions of the SparkDataFrame. Defaults to 1, this is +#' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. 
 #' @rdname createDataFrame
 #' @export
 #' @examples
 #'\dontrun{
 #' sparkR.session()
 #' df1 <- as.DataFrame(iris)
 #' df2 <- as.DataFrame(list(3,4,5,6))
 #' df3 <- createDataFrame(iris)
+#' df4 <- createDataFrame(cars, numPartitions = 2)
 #' }
 #' @name createDataFrame
 #' @method createDataFrame default
 #' @note createDataFrame since 1.4.0
 # TODO(davies): support sampling and infer type from NA
-createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
+createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0,
+                                    numPartitions = NULL) {
   sparkSession <- getSparkSession()
 
   if (is.data.frame(data)) {
@@ -233,7 +238,11 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
   if (is.list(data)) {
     sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
-    rdd <- parallelize(sc, data)
+    if (!is.null(numPartitions)) {
+      rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
+    } else {
+      rdd <- parallelize(sc, data, numSlices = 1)
+    }
   } else if (inherits(data, "RDD")) {
     rdd <- data
   } else {
@@ -283,14 +292,13 @@ createDataFrame <- function(x, ...) {
   dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
 }
 
-#' @param samplingRatio Currently not used.
 #' @rdname createDataFrame
 #' @aliases createDataFrame
 #' @export
 #' @method as.DataFrame default
 #' @note as.DataFrame since 1.6.0
-as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
-  createDataFrame(data, schema)
+as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) {
+  createDataFrame(data, schema, samplingRatio, numPartitions)
 }
 
 #' @param ... additional argument(s).
@@ -634,7 +642,7 @@ tableNames <- function(x, ...) {
 cacheTable.default <- function(tableName) {
   sparkSession <- getSparkSession()
   catalog <- callJMethod(sparkSession, "catalog")
-  callJMethod(catalog, "cacheTable", tableName)
+  invisible(callJMethod(catalog, "cacheTable", tableName))
 }
 
 cacheTable <- function(x, ...) {
@@ -663,7 +671,7 @@ cacheTable <- function(x, ...) {
 uncacheTable.default <- function(tableName) {
   sparkSession <- getSparkSession()
   catalog <- callJMethod(sparkSession, "catalog")
-  callJMethod(catalog, "uncacheTable", tableName)
+  invisible(callJMethod(catalog, "uncacheTable", tableName))
 }
 
 uncacheTable <- function(x, ...) {
@@ -686,7 +694,7 @@ uncacheTable <- function(x, ...) {
 clearCache.default <- function() {
   sparkSession <- getSparkSession()
   catalog <- callJMethod(sparkSession, "catalog")
-  callJMethod(catalog, "clearCache")
+  invisible(callJMethod(catalog, "clearCache"))
 }
 
 clearCache <- function() {
@@ -730,6 +738,7 @@ dropTempTable <- function(x, ...) {
 #' If the view has been cached before, then it will also be uncached.
 #'
 #' @param viewName the name of the view to be dropped.
+#' @return TRUE if the view is dropped successfully, FALSE otherwise.
 #' @rdname dropTempView
 #' @name dropTempView
 #' @export
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 438d77a388f0..634bdcb52363 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -87,10 +87,20 @@ objectFile <- function(sc, path, minPartitions = NULL) {
 #' in the list are split into \code{numSlices} slices and distributed to nodes
 #' in the cluster.
 #'
-#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function
-#' will write it to disk and send the file name to JVM. Also to make sure each slice is not
+#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function
+#' will write it to disk and send the file name to JVM. Also to make sure each slice is not
 #' larger than that limit, number of slices may be increased.
 #'
+#' In 2.2.0 we are changing how numSlices is used/computed, to handle
+#' 1 < (length(coll) / numSlices) << length(coll) better, and to get the exact number of slices.
+#' This change affects both createDataFrame and spark.lapply.
+#' In the specific case where it is used to convert an R native object into a SparkDataFrame, it
+#' has always been kept at the default of 1. In the case the object is large, we explicitly set
+#' the parallelism to numSlices (which is still 1).
+#'
+#' Specifically, we are changing the split positions to match the calculation in positions() of
+#' ParallelCollectionRDD in Spark.
+#'
 #' @param sc SparkContext to use
 #' @param coll collection to parallelize
 #' @param numSlices number of partitions to create in the RDD
@@ -107,6 +117,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # TODO: bound/safeguard numSlices
   # TODO: unit tests for if the split works for all primitives
   # TODO: support matrix, data frame, etc
+
+  # Note: for data.frame, createDataFrame turns it into a list before it calls here.
   # nolint start
   # suppress lintr warning: Place a space before left parenthesis, except in a function call.
   if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
@@ -128,12 +140,29 @@ parallelize <- function(sc, coll, numSlices = 1) {
   objectSize <- object.size(coll)
 
   # For large objects we make sure the size of each slice is also smaller than sizeLimit
-  numSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
-  if (numSlices > length(coll))
-    numSlices <- length(coll)
+  numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
+  if (numSerializedSlices > length(coll))
+    numSerializedSlices <- length(coll)
+
+  # Generate the slice ids to put each row
+  # For instance, for numSerializedSlices of 22, length of 50
+  #  [1]  0  0  2  2  4  4  6  6  6  9  9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
+  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
+  # Notice the slice groups with 3 slices (i.e. 6, 15, 22) are roughly evenly spaced.
+  # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
+  splits <- if (numSerializedSlices > 0) {
+    unlist(lapply(0: (numSerializedSlices - 1), function(x) {
+      # nolint start
+      start <- trunc((x * length(coll)) / numSerializedSlices)
+      end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
+      # nolint end
+      rep(start, end - start)
+    }))
+  } else {
+    1
+  }
 
-  sliceLen <- ceiling(length(coll) / numSlices)
-  slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)])
+  slices <- split(coll, splits)
 
   # Serialize each slice: obtain a list of raws, or a list of lists (slices) of
   # 2-tuples of raws
@@ -301,7 +330,13 @@ spark.addFile <- function(path, recursive = FALSE) {
 #'}
 #' @note spark.getSparkFilesRootDirectory since 2.1.0
 spark.getSparkFilesRootDirectory <- function() {
-  callJStatic("org.apache.spark.SparkFiles", "getRootDirectory")
+  if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") {
+    # Running on driver.
+    callJStatic("org.apache.spark.SparkFiles", "getRootDirectory")
+  } else {
+    # Running on worker.
+    Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR")
+  }
 }
 
 #' Get the absolute path of a file added through spark.addFile.
@@ -316,7 +351,13 @@ spark.getSparkFilesRootDirectory <- function() {
 #'}
 #' @note spark.getSparkFiles since 2.1.0
 spark.getSparkFiles <- function(fileName) {
-  callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName))
+  if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") {
+    # Running on driver.
+    callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName))
+  } else {
+    # Running on worker.
+    file.path(spark.getSparkFilesRootDirectory(), as.character(fileName))
+  }
 }
 
 #' Run a function over a list of elements, distributing the computations with Spark
@@ -379,5 +420,5 @@ spark.lapply <- function(list, func) {
 #' @note setLogLevel since 2.0.0
 setLogLevel <- function(level) {
   sc <- getSparkContext()
-  callJMethod(sc, "setLogLevel", level)
+  invisible(callJMethod(sc, "setLogLevel", level))
 }
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 4d94b4cd05d4..2992baee6809 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -286,6 +286,28 @@ setMethod("ceil",
             column(jc)
           })
 
+#' Returns the first column that is not NA
+#'
+#' Returns the first column that is not NA, or NA if all inputs are.
+#'
+#' @rdname coalesce
+#' @name coalesce
+#' @family normal_funcs
+#' @export
+#' @aliases coalesce,Column-method
+#' @examples \dontrun{coalesce(df$c, df$d, df$e)}
+#' @note coalesce(Column) since 2.1.1
+setMethod("coalesce",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(x, ...), function (x) {
+              stopifnot(class(x) == "Column")
+              x@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "coalesce", jcols)
+            column(jc)
+          })
+
 #' Though scala functions has "col" function, we don't expose it in SparkR
 #' because we don't want to conflict with the "col" function in the R base
 #' package and we also have "column" function exported which is an alias of "col".
 col <- function(x) {
 #' Returns a Column based on the given column name
 #'
 #' Returns a Column based on the given column name.
-#
+#'
 #' @param x Character column name.
 #'
 #' @rdname column
 #' @name column
 #' @family normal_funcs
 #' @export
 #' @aliases column,character-method
-#' @examples \dontrun{column(df)}
+#' @examples \dontrun{column("name")}
 #' @note column since 1.6.0
 setMethod("column",
           signature(x = "character"),
@@ -1485,7 +1507,7 @@ setMethod("soundex",
 
 #' Return the partition ID as a column
 #'
-#' Return the partition ID of the Spark task as a SparkDataFrame column.
+#' Return the partition ID as a SparkDataFrame column.
 #' Note that this is nondeterministic because it depends on data partitioning and
 #' task scheduling.
 #'
@@ -2296,7 +2318,7 @@ setMethod("n", signature(x = "Column"),
 #' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All
 #' pattern letters of \code{java.text.SimpleDateFormat} can be used.
 #'
-#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a
+#' Note: Whenever possible, use specialized functions like \code{year}. These benefit from a
 #' specialized implementation.
 #'
 #' @param y Column to compute on.
@@ -2317,7 +2339,8 @@ setMethod("date_format", signature(y = "Column", x = "character"),
 
 #' from_utc_timestamp
 #'
-#' Assumes given timestamp is UTC and converts to given timezone.
+#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp
+#' that corresponds to the same time of day in the given timezone.
 #'
 #' @param y Column to compute on.
 #' @param x time zone to use.
@@ -2340,7 +2363,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"),
 #' Locate the position of the first occurrence of substr column in the given string.
 #' Returns null if either of the arguments are null.
 #'
-#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+#' Note: The position is not zero based, but 1 based index. Returns 0 if substr
 #' could not be found in str.
 #'
 #' @param y column to check
@@ -2391,7 +2414,8 @@ setMethod("next_day", signature(y = "Column", x = "character"),
 
 #' to_utc_timestamp
 #'
-#' Assumes given timestamp is in given timezone and converts to UTC.
+#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns
+#' another timestamp that corresponds to the same time of day in UTC.
 #'
 #' @param y Column to compute on
 #' @param x timezone to use
@@ -2539,7 +2563,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"),
 
 #' shiftRight
 #'
-#' Shift the given value numBits right. If the given value is a long value, it will return
+#' (Signed) shift the given value numBits right. If the given value is a long value, it will return
 #' a long value else it will return an integer value.
 #'
 #' @param y column to compute on.
@@ -2777,7 +2801,8 @@ setMethod("window", signature(x = "Column"),
 #' locate
 #'
 #' Locate the position of the first occurrence of substr.
-#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+#'
+#' Note: The position is not zero based, but 1 based index. Returns 0 if substr
 #' could not be found in str.
 #'
 #' @param substr a character string to be matched.
@@ -2823,7 +2848,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"),
 
 #' rand
 #'
-#' Generate a random column with i.i.d. samples from U[0.0, 1.0].
+#' Generate a random column with independent and identically distributed (i.i.d.) samples
+#' from U[0.0, 1.0].
 #'
 #' @param seed a random seed. Can be missing.
 #' @family normal_funcs
@@ -2852,7 +2878,8 @@ setMethod("rand", signature(seed = "numeric"),
 
 #' randn
 #'
-#' Generate a column with i.i.d. samples from the standard normal distribution.
+#' Generate a column with independent and identically distributed (i.i.d.) samples from
+#' the standard normal distribution.
 #'
 #' @param seed a random seed. Can be missing.
 #' @family normal_funcs
@@ -3145,7 +3172,8 @@ setMethod("cume_dist",
 #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
 #' sequence when there are ties. That is, if you were ranking a competition using dense_rank
 #' and had three people tie for second place, you would say that all three were in second
-#' place and that the next person came in third.
+#' place and that the next person came in third. Rank would give sequential numbers, so
+#' the person that came in third place (after the ties) would register as coming in fifth.
 #'
 #' This is equivalent to the \code{DENSE_RANK} function in SQL.
 #'
@@ -3316,10 +3344,11 @@ setMethod("percent_rank",
 #'
 #' Window function: returns the rank of rows within a window partition.
 #'
-#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking
-#' sequence when there are ties. That is, if you were ranking a competition using denseRank
+#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
+#' sequence when there are ties. That is, if you were ranking a competition using dense_rank
 #' and had three people tie for second place, you would say that all three were in second
-#' place and that the next person came in third.
+#' place and that the next person came in third. Rank would give sequential numbers, so
+#' the person that came in third place (after the ties) would register as coming in fifth.
 #'
 #' This is equivalent to the RANK function in SQL.
 #'
@@ -3442,8 +3471,8 @@ setMethod("size",
 
 #' sort_array
 #'
-#' Sorts the input array for the given column in ascending order,
-#' according to the natural ordering of the array elements.
+#' Sorts the input array in ascending or descending order according
+#' to the natural ordering of the array elements.
 #'
 #' @param x A Column to sort
 #' @param asc A logical flag indicating the sorting order.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 0271b26a10a9..f018ec9e46a6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -28,7 +28,7 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") })
 # @rdname coalesce
 # @seealso repartition
 # @export
-setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") })
+setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") })
 
 # @rdname checkpoint-methods
 # @export
@@ -138,9 +138,9 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") })
 # @export
 setGeneric("name", function(x) { standardGeneric("name") })
 
-# @rdname getNumPartitions
+# @rdname getNumPartitionsRDD
 # @export
-setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") })
+setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") })
 
 # @rdname getNumPartitions
 # @export
@@ -406,6 +406,13 @@ setGeneric("attach")
 #' @export
 setGeneric("cache", function(x) { standardGeneric("cache") })
 
+#' @rdname coalesce
+#' @param x a Column or a SparkDataFrame.
+#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally
+#' provided.
+#' @export
+setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") })
+
 #' @rdname collect
 #' @export
 setGeneric("collect", function(x, ...) { standardGeneric("collect") })
@@ -492,6 +499,10 @@ setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
 #' @export
 setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") })
 
+# @rdname getNumPartitions
+# @export
+setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") })
+
 #' @rdname summary
 #' @export
 setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -1343,6 +1354,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
 setGeneric("spark.gaussianMixture",
            function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
 
+#' @rdname spark.gbt
+#' @export
+setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") })
+
 #' @rdname spark.glm
 #' @export
 setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
@@ -1369,7 +1384,7 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.
 
 #' @rdname spark.mlp
 #' @export
-setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") })
+setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") })
{ standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes #' @export diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 69b0a523b84e..4ca7aa664e02 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -21,9 +21,9 @@ #' Download and Install Apache Spark to a Local Directory #' #' \code{install.spark} downloads and installs Spark to a local directory if -#' it is not found. The Spark version we use is the same as the SparkR version. -#' Users can specify a desired Hadoop version, the remote mirror site, and -#' the directory where the package is installed locally. +#' it is not found. If SPARK_HOME is set in the environment, and that directory is found, that is +#' returned. The Spark version we use is the same as the SparkR version. Users can specify a desired +#' Hadoop version, the remote mirror site, and the directory where the package is installed locally. #' #' The full url of remote file is inferred from \code{mirrorUrl} and \code{hadoopVersion}. #' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder @@ -50,11 +50,11 @@ #' \itemize{ #' \item Mac OS X: \file{~/Library/Caches/spark} #' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark} -#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. +#' \item Windows: \file{\%LOCALAPPDATA\%\\Apache\\Spark\\Cache}. #' } #' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir #' and force re-install Spark (in case the local directory or file is corrupted) -#' @return \code{install.spark} returns the local directory where Spark is found or installed +#' @return the (invisible) local directory where Spark is found or installed #' @rdname install.spark #' @name install.spark #' @aliases install.spark @@ -68,6 +68,16 @@ #' \href{http://spark.apache.org/downloads.html}{Apache Spark} install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, localDir = NULL, overwrite = FALSE) { + sparkHome <- Sys.getenv("SPARK_HOME") + if (isSparkRShell()) { + stopifnot(nchar(sparkHome) > 0) + message("Spark is already running in sparkR shell.") + return(invisible(sparkHome)) + } else if (!is.na(file.info(sparkHome)$isdir)) { + message("Spark package found in SPARK_HOME: ", sparkHome) + return(invisible(sparkHome)) + } + version <- paste0("spark-", packageVersion("SparkR")) hadoopVersion <- tolower(hadoopVersion) hadoopVersionName <- hadoopVersionName(hadoopVersion) @@ -79,19 +89,28 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, dir.create(localDir, recursive = TRUE) } - packageLocalDir <- file.path(localDir, packageName) - if (overwrite) { message(paste0("Overwrite = TRUE: download and overwrite the tar file", "and Spark package directory if they exist.")) } + releaseUrl <- Sys.getenv("SPARKR_RELEASE_DOWNLOAD_URL") + if (releaseUrl != "") { + packageName <- basenameSansExtFromUrl(releaseUrl) + } + + packageLocalDir <- file.path(localDir, packageName) + # can use dir.exists(packageLocalDir) under R 3.2.0 or later if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { - fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" - msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), - packageLocalDir) - message(msg) + if (releaseUrl != "") { + message(paste(packageName, "found, setting SPARK_HOME to", packageLocalDir)) + } else { + fmt <- "%s for Hadoop %s found, setting SPARK_HOME to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", 
hadoopVersion), + packageLocalDir) + message(msg) + } Sys.setenv(SPARK_HOME = packageLocalDir) return(invisible(packageLocalDir)) } else { @@ -104,14 +123,37 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, if (tarExists && !overwrite) { message("tar file found.") } else { - robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + if (releaseUrl != "") { + message("Downloading from alternate URL:\n- ", releaseUrl) + success <- downloadUrl(releaseUrl, packageLocalPath) + if (!success) { + unlink(packageLocalPath) + stop(paste0("Fetch failed from ", releaseUrl)) + } + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } } message(sprintf("Installing to %s", localDir)) - untar(tarfile = packageLocalPath, exdir = localDir) - if (!tarExists || overwrite) { + # There are two ways untar can fail - untar could stop() on errors like incomplete block on file + # or, tar command can return failure code + success <- tryCatch(untar(tarfile = packageLocalPath, exdir = localDir) == 0, + error = function(e) { + message(e) + message() + FALSE + }, + warning = function(w) { + # Treat warning as error, add an empty line with message() + message(w) + message() + FALSE + }) + if (!tarExists || overwrite || !success) { unlink(packageLocalPath) } + if (!success) stop("Extract archive failed.") message("DONE.") Sys.setenv(SPARK_HOME = packageLocalDir) message(paste("SPARK_HOME set to", packageLocalDir)) @@ -121,8 +163,7 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { # step 1: use user-provided url if (!is.null(mirrorUrl)) { - msg <- sprintf("Use user-provided mirror site: %s.", mirrorUrl) - message(msg) + message("Use user-provided mirror site: ", mirrorUrl) success <- directDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) if (success) { @@ -142,7 +183,7 @@ robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, pa packageName, packageLocalPath) if (success) return() } else { - message("Unable to find preferred mirror site.") + message("Unable to download from preferred mirror site: ", mirrorUrl) } # step 3: use backup option @@ -151,8 +192,11 @@ robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, pa success <- directDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) if (success) { - return(packageLocalPath) + return() } else { + # remove any partially downloaded file + unlink(packageLocalPath) + message("Unable to download from default mirror site: ", mirrorUrl) msg <- sprintf(paste("Unable to download Spark %s for Hadoop %s.", "Please check network connection, Hadoop version,", "or provide other mirror sites."), @@ -182,17 +226,25 @@ getPreferredMirror <- function(version, packageName) { } directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { - packageRemotePath <- paste0( - file.path(mirrorUrl, version, packageName), ".tgz") + packageRemotePath <- paste0(file.path(mirrorUrl, version, packageName), ".tgz") fmt <- "Downloading %s for Hadoop %s from:\n- %s" msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), packageRemotePath) message(msg) + downloadUrl(packageRemotePath, packageLocalPath) +} - isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), +downloadUrl <- function(remotePath, localPath) { 
+ isFail <- tryCatch(download.file(remotePath, localPath), error = function(e) { - message(sprintf("Fetch failed from %s", mirrorUrl)) - print(e) + message(e) + message() + TRUE + }, + warning = function(w) { + # Treat warning as error, add an empty line with message() + message(w) + message() TRUE }) !isFail @@ -218,12 +270,11 @@ sparkCachePath <- function() { if (.Platform$OS.type == "windows") { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { - msg <- paste("%LOCALAPPDATA% not found.", + stop(paste("%LOCALAPPDATA% not found.", "Please define the environment variable", - "or restart and enter an installation path in localDir.") - stop(msg) + "or restart and enter an installation path in localDir.")) } else { - path <- file.path(winAppPath, "spark", "spark", "Cache") + path <- file.path(winAppPath, "Apache", "Spark", "Cache") } } else if (.Platform$OS.type == "unix") { if (Sys.info()["sysname"] == "Darwin") { diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7a220b8d53a2..1ddfa30d7fa7 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -116,6 +116,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -124,7 +138,8 @@ setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} @@ -138,7 +153,8 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL @@ -157,7 +173,7 @@ predict_internal <- function(object, newData) { #' Generalized Linear Models #' -#' Fits generalized linear model against a Spark DataFrame. +#' Fits generalized linear model against a SparkDataFrame. #' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make #' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. 
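As an aside, the fit/summarize/predict cycle described above can be sketched as follows; this is illustrative only, assuming an active sparkR.session() and the iris data:

# Hedged sketch of the spark.glm workflow (not part of this patch).
df <- createDataFrame(iris)    # '.' in column names becomes '_' (e.g. Sepal_Length)
model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")
summary(model)                 # coefficients, deviance, AIC, IRLS iterations
head(predict(model, df))       # predictions land in a "prediction" column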
#' @@ -168,6 +184,8 @@ predict_internal <- function(object, newData) { #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer to the R family documentation at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance @@ -175,7 +193,7 @@ predict_internal <- function(object, newData) { #' @param regParam regularization parameter for L2 regularization. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method -#' @return \code{spark.glm} returns a fitted generalized linear model +#' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm #' @export @@ -220,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), weightCol <- "" } + # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, + "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter), as.character(weightCol), regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -236,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer to the R family documentation at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. @@ -261,10 +282,12 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). #' @param object a fitted generalized linear model. -#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients, null/residual deviance, null/residual degrees -#' of freedom, AIC and number of iterations IRLS takes. -#' +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes +#' coefficients, standard error of coefficients, t value and p value), +#' \code{null.deviance}/\code{deviance} (null/residual deviance), +#' \code{df.null}/\code{df.residual} (null/residual degrees of freedom), \code{aic} (AIC) +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, +#' the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -287,9 +310,18 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), } else { dataFrame(callJMethod(jobj, "rDevianceResiduals")) } - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std.
Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) + # If the underlying WeightedLeastSquares using "normal" solver, we can provide + # coefficients, standard error of coefficients, t value and p value. Otherwise, + # it will be fitted by local "l-bfgs", we can only provide coefficients. + if (length(features) == length(coefficients)) { + coefficients <- matrix(coefficients, ncol = 1) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + } else { + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + } ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, dispersion = dispersion, null.deviance = null.deviance, deviance = deviance, df.null = df.null, df.residual = df.residual, @@ -301,7 +333,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), # Prints the summary of GeneralizedLinearRegressionModel #' @rdname spark.glm -#' @param x summary object of fitted generalized linear model returned by \code{summary} function +#' @param x summary object of fitted generalized linear model returned by \code{summary} function. #' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { @@ -334,7 +366,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction" +#' "prediction". #' @rdname spark.glm #' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 @@ -348,7 +380,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". #' @rdname spark.naiveBayes #' @export #' @note predict(NaiveBayesModel) since 2.0.0 @@ -360,8 +392,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"), # Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} #' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. -#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{apriori} (the label distribution) and +#' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes #' @export #' @note summary(NaiveBayesModel) since 2.0.0 @@ -382,9 +415,9 @@ setMethod("summary", signature(object = "NaiveBayesModel"), # Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() -#' @param newData A SparkDataFrame for testing +#' @param newData A SparkDataFrame for testing. #' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities -#' vectors named "topicDistribution" +#' vectors named "topicDistribution". 
#' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method #' @export @@ -398,7 +431,8 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. #' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. -#' @return \code{summary} returns a list containing +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes #' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for #' the prior placed on documents distributions over topics \code{theta}} #' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or @@ -449,7 +483,7 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr # Saves the Latent Dirichlet Allocation model to the input path. -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -465,19 +499,19 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' Isotonic Regression Model #' -#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). +#' Fits an Isotonic Regression model against a SparkDataFrame, similarly to R's isoreg(). #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or -#' antitonic/decreasing (FALSE) +#' antitonic/decreasing (FALSE). #' @param featureIndex The index of the feature if \code{featuresCol} is a vector column -#' (default: 0), no effect otherwise +#' (default: 0), no effect otherwise. #' @param weightCol The weight column name. #' @param ... additional arguments passed to the method. -#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model +#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg @@ -509,7 +543,7 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' @note spark.isoreg since 2.1.0 setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { - formula <- paste0(deparse(formula), collapse = "") + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -523,9 +557,9 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" # Predicted values based on an isotonicRegression model -#' @param object a fitted IsotonicRegressionModel -#' @param newData SparkDataFrame for testing -#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' @param object a fitted IsotonicRegressionModel. +#' @param newData SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. 
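For reference, a minimal spark.isoreg round trip consistent with the parameters documented above, on illustrative toy data:

# Hedged sketch: isotonic fit on a tiny SparkDataFrame.
df <- createDataFrame(data.frame(feature = 1:5, label = c(1, 2, 2, 5, 4)))
model <- spark.isoreg(df, label ~ feature, isotonic = TRUE)
summary(model)            # list with $boundaries and $predictions
head(predict(model, df))  # monotone fitted values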
#' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method #' @export @@ -537,7 +571,9 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), # Get the summary of an IsotonicRegressionModel model -#' @return \code{summary} returns the model's boundaries and prediction as lists +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{boundaries} (boundaries in increasing order) +#' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method #' @export @@ -552,7 +588,7 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' K-Means Clustering Model #' -#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). +#' Fits a k-means clustering model against a SparkDataFrame, similarly to R's kmeans(). #' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make #' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' @@ -563,6 +599,10 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @param k number of centers. #' @param maxIter maximum iteration number. #' @param initMode the initialization algorithm chosen to fit the model. +#' @param seed the random seed for cluster initialization. +#' @param initSteps the number of steps for the k-means|| initialization mode. +#' This is an advanced setting; the default of 2 is almost always enough. Must be > 0. +#' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. #' @rdname spark.kmeans @@ -592,11 +632,16 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @note spark.kmeans since 2.0.0 #' @seealso \link{predict}, \link{read.ml}, \link{write.ml} setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { + function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"), + seed = NULL, initSteps = 2, tol = 1E-4) { formula <- paste(deparse(formula), collapse = "") initMode <- match.arg(initMode) + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, - as.integer(k), as.integer(maxIter), initMode) + as.integer(k), as.integer(maxIter), initMode, seed, + as.integer(initSteps), as.numeric(tol)) new("KMeansModel", jobj = jobj) }) @@ -634,7 +679,14 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (the configured number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), \code{cluster} +#' (cluster centers of the transformed data), \code{is.loaded} (whether the model is loaded +#' from a saved file), and \code{clusterSize} +#' (the actual number of cluster centers. When using initMode = "random", +#' \code{clusterSize} may not equal \code{k}).
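A short sketch exercising the new seed, initSteps and tol arguments introduced above (values are illustrative):

# Hedged sketch of the extended spark.kmeans signature.
df <- createDataFrame(iris)
model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 3,
                      initMode = "k-means||", seed = 42, initSteps = 2, tol = 1e-4)
s <- summary(model)
s$clusterSize   # actual number of centers; may differ from s$k under initMode = "random"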
#' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -646,16 +698,17 @@ setMethod("summary", signature(object = "KMeansModel"), coefficients <- callJMethod(jobj, "coefficients") k <- callJMethod(jobj, "k") size <- callJMethod(jobj, "size") - coefficients <- t(matrix(coefficients, ncol = k)) + clusterSize <- callJMethod(jobj, "clusterSize") + coefficients <- t(matrix(coefficients, ncol = clusterSize)) colnames(coefficients) <- unlist(features) - rownames(coefficients) <- 1:k + rownames(coefficients) <- 1:clusterSize cluster <- if (is.loaded) { NULL } else { dataFrame(callJMethod(jobj, "cluster")) } - list(coefficients = coefficients, size = size, - cluster = cluster, is.loaded = is.loaded) + list(k = k, coefficients = coefficients, size = size, + cluster = cluster, is.loaded = is.loaded, clusterSize = clusterSize) }) # Predicted values based on a k-means model @@ -672,22 +725,21 @@ setMethod("predict", signature(object = "KMeansModel"), #' Logistic Regression Model #' -#' Fits an logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression +#' Fits a logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression #' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". +#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -705,13 +757,10 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". #' @param ...
additional arguments passed to the method. -#' @return \code{spark.logit} returns a fitted logistic regression model +#' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit @@ -720,46 +769,36 @@ setMethod("predict", signature(object = "KMeansModel"), #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) -#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) -#' binary_data <- as.data.frame(cbind(label, feature)) -#' binary_df <- createDataFrame(binary_data) -#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) -#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) -#' -#' # summary of binary logistic regression -#' blr_summary <- summary(blr_model) -#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' df <- createDataFrame(iris) +#' training <- df[df$Species %in% c("versicolor", "virginica"), ] +#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' summary <- summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, training) +#' #' # save fitted model to input path #' path <- "path/to/model" -#' write.ml(blr_model, path) +#' write.ml(model, path) #' #' # can also read back the saved model and predict #' # Note that summary does not work on loaded model #' savedModel <- read.ml(path) -#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' summary(savedModel) #' #' # multinomial logistic regression #' -#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) -#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) -#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) -#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) -#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) -#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) -#' df <- createDataFrame(data) +#' df <- createDataFrame(iris) +#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' summary <- summary(model) #' -#' # Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) -#' predict1 <- collect(select(predict(model, df), "prediction")) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, - tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, - probabilityCol = "probability") { - formula <- paste0(deparse(formula), collapse = "") + tol = 1E-6, family = "auto", standardization = TRUE, + thresholds = 0.5, weightCol = NULL) { + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -768,10 +807,9 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), - as.numeric(tol), as.logical(fitIntercept), - as.character(family), as.logical(standardization), - as.array(thresholds), as.character(weightCol), - as.integer(aggregationDepth), as.character(probabilityCol)) +
as.numeric(tol), as.character(family), + as.logical(standardization), as.array(thresholds), + as.character(weightCol)) new("LogisticRegressionModel", jobj = jobj) }) @@ -790,9 +828,9 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of a LogisticRegressionModel -#' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @param object a LogisticRegressionModel fitted by \code{spark.logit}. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -800,33 +838,21 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - - if (is.loaded) { - stop("Loaded model doesn't have training summary.") + features <- callJMethod(jobj, "rFeatures") + labels <- callJMethod(jobj, "labels") + coefficients <- callJMethod(jobj, "rCoefficients") + nCol <- length(coefficients) / length(features) + coefficients <- matrix(coefficients, ncol = nCol) + # If nCol == 1, this is a binomial logistic regression model with pivoting. + # Otherwise, it's a multinomial logistic regression model without pivoting. + if (nCol == 1) { + colnames(coefficients) <- c("Estimate") + } else { + colnames(coefficients) <- unlist(labels) } + rownames(coefficients) <- unlist(features) - roc <- dataFrame(callJMethod(jobj, "roc")) - - areaUnderROC <- callJMethod(jobj, "areaUnderROC") - - pr <- dataFrame(callJMethod(jobj, "pr")) - - fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) - - precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) - - recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) - - totalIterations <- callJMethod(jobj, "totalIterations") - - objectiveHistory <- callJMethod(jobj, "objectiveHistory") - - list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, - fMeasureByThreshold = fMeasureByThreshold, - precisionByThreshold = precisionByThreshold, - recallByThreshold = recallByThreshold, - totalIterations = totalIterations, objectiveHistory = objectiveHistory) + list(coefficients = coefficients) }) #' Multilayer Perceptron Classification Model @@ -840,8 +866,10 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' Multilayer Perceptron} #' #' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param blockSize blockSize parameter. -#' @param layers integer vector containing the number of nodes for each layer +#' @param layers integer vector containing the number of nodes for each layer. #' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. @@ -852,7 +880,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @param ... additional arguments passed to the method.
#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp -#' @aliases spark.mlp,SparkDataFrame-method +#' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} #' @export @@ -861,7 +889,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") #' #' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", +#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", #' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, #' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) #' @@ -878,9 +906,10 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' summary(savedModel) #' } #' @note spark.mlp since 2.1.0 -setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, +setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { + formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { stop("layers must be an integer vector with length > 1.") } @@ -895,7 +924,7 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"), initialWeights <- as.array(as.numeric(na.omit(initialWeights))) } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", - "fit", data@sdf, as.integer(blockSize), as.array(layers), + "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), as.numeric(stepSize), seed, initialWeights) new("MultilayerPerceptronClassificationModel", jobj = jobj) @@ -918,9 +947,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel # Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} #' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{labelCount}, \code{layers}, and -#' \code{weights}. For \code{weights}, it is a numeric vector with length equal to -#' the expected given the architecture (i.e., for 8-10-2 network, 100 connection weights). +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs} +#' (number of outputs), \code{layers} (array of layer sizes including input +#' and output layers), and \code{weights} (the weights of layers). +#' For \code{weights}, it is a numeric vector with length equal to that expected +#' given the architecture (i.e., for an 8-10-2 network, 112 connection weights).
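A hedged sketch of reading the reworked summary fields, assuming the sample libsvm file shipped with Spark (4 features, 3 classes):

# Sketch: inspect the new spark.mlp summary fields.
df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
model <- spark.mlp(df, label ~ features, layers = c(4, 5, 3), blockSize = 128)
s <- summary(model)
s$numOfInputs      # 4, the input layer size
s$numOfOutputs     # 3, the output layer size
length(s$weights)  # total connection weights implied by layers c(4, 5, 3)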
#' @rdname spark.mlp #' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method @@ -928,10 +960,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), function(object) { jobj <- object@jobj - labelCount <- callJMethod(jobj, "labelCount") layers <- unlist(callJMethod(jobj, "layers")) + numOfInputs <- head(layers, n = 1) + numOfOutputs <- tail(layers, n = 1) weights <- callJMethod(jobj, "weights") - list(labelCount = labelCount, layers = layers, weights = weights) + list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs, + layers = layers, weights = weights) }) #' Naive Bayes Models @@ -983,7 +1017,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form # Saves the Bernoulli naive Bayes model to the input path. -#' @param path the directory where the model is saved +#' @param path the directory where the model is saved. #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1057,7 +1091,7 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode # Save fitted IsotonicRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1072,7 +1106,7 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char # Save fitted LogisticRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1141,6 +1175,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1195,14 +1233,14 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param data A SparkDataFrame for training. +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". 
#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size @@ -1215,7 +1253,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' parameter if libSVM-format column is used as the features column. #' @param maxVocabSize maximum vocabulary size, default 1 << 18 #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} @@ -1263,8 +1301,9 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, -#' intercept and log(scale) +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{coefficients} (features, coefficients, +#' intercept and log(scale)). #' @rdname spark.survreg #' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 @@ -1284,7 +1323,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0). +#' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg #' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 @@ -1295,7 +1334,7 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' Multivariate Gaussian Mixture Model (GMM) #' -#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's +#' Fits multivariate gaussian mixture model against a SparkDataFrame, similarly to R's #' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, #' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} #' to save/load fitted models. @@ -1351,7 +1390,9 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns summary of the fitted model, which is a list. +#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), +#' \code{sigma} (sigma), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1415,7 +1456,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. #' @param itemCol column name for item ids. 
Ids must be (or can be coerced into) integers. #' @param rank rank of the matrix factorization (> 0). -#' @param reg regularization parameter (>= 0). +#' @param regParam regularization parameter (>= 0). #' @param maxIter maximum number of iterations (>= 0). #' @param nonnegative logical value indicating whether to apply nonnegativity constraints. #' @param implicitPrefs logical value indicating whether to use implicit preference. @@ -1425,7 +1466,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.als} returns a fitted ALS model +#' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als @@ -1454,21 +1495,21 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' # set other arguments #' modelS <- spark.als(df, "rating", "user", "item", rank = 20, -#' reg = 0.1, nonnegative = TRUE) +#' regParam = 0.1, nonnegative = TRUE) #' statsS <- summary(modelS) #' } #' @note spark.als since 2.1.0 setMethod("spark.als", signature(data = "SparkDataFrame"), function(data, ratingCol = "rating", userCol = "user", itemCol = "item", - rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE, + rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, checkpointInterval = 10, seed = 0) { if (!is.numeric(rank) || rank <= 0) { stop("rank should be a positive number.") } - if (!is.numeric(reg) || reg < 0) { - stop("reg should be a nonnegative number.") + if (!is.numeric(regParam) || regParam < 0) { + stop("regParam should be a nonnegative number.") } if (!is.numeric(maxIter) || maxIter <= 0) { stop("maxIter should be a positive number.") @@ -1476,7 +1517,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), - reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, as.integer(numUserBlocks), as.integer(numItemBlocks), as.integer(checkpointInterval), as.integer(seed)) new("ALSModel", jobj = jobj) @@ -1485,9 +1526,11 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), # Returns a summary of the ALS model produced by spark.als. #' @param object a fitted ALS model. -#' @return \code{summary} returns a list containing the names of the user column, -#' the item column and the rating column, the estimated user and item factors, -#' rank, regularization parameter and maximum number of iterations used in training. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{user} (the names of the user column), +#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} +#' (the estimated user factors), \code{itemFactors} (the estimated item factors), +#' and \code{rank} (rank of the matrix factorization model). 
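A toy sketch using the renamed regParam argument (data is illustrative, not from the patch):

# Hedged sketch: spark.als now takes regParam instead of reg.
ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(2, 2, 5.0))
df <- createDataFrame(ratings, c("user", "item", "rating"))
model <- spark.als(df, "rating", "user", "item", rank = 10, regParam = 0.1)
summary(model)$rank   # rank of the factorization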
#' @rdname spark.als #' @aliases summary,ALSModel-method #' @export @@ -1570,14 +1613,14 @@ setMethod("write.ml", signature(object = "ALSModel", path = "character"), #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) #' df <- createDataFrame(data) -#' test <- spark.ktest(df, "test", "norm", c(0, 1)) +#' test <- spark.kstest(df, "test", "norm", c(0, 1)) #' #' # get a summary of the test result #' testSummary <- summary(test) #' testSummary #' #' # print out the summary in an organized way -#' print.summary.KSTest(test) +#' print.summary.KSTest(testSummary) #' } #' @note spark.kstest since 2.1.0 setMethod("spark.kstest", signature(data = "SparkDataFrame"), @@ -1600,9 +1643,10 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), # Get the summary of Kolmogorov-Smirnov (KS) Test. #' @param object test result object of KSTest by \code{spark.kstest}. -#' @return \code{summary} returns a list containing the p-value, test statistic computed for the -#' test, the null hypothesis with its parameters tested against -#' and degrees of freedom of the test. +#' @return \code{summary} returns summary information of KSTest object, which is a list. +#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic +#' computed for the test), \code{nullHypothesis} (the null hypothesis with its +#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method #' @export @@ -1644,33 +1688,36 @@ print.summary.KSTest <- function(x, ...) { #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxDepth Maximum depth of the tree (>= 0). #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing #' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' >= 2 and >= number of categories in any categorical feature. #' @param numTrees Number of trees to train (>= 1). #' @param impurity Criterion used for information gain calculation. #' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini". (default = gini) -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. 
#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. (default = 1.0) -#' @param probabilityCol column name for predicted class conditional probabilities, only for -#' classification. (default = "probability") +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for setting the checkpoint interval (>= 1) or disabling checkpoint (-1). #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often the +#' cache should be checkpointed, or disable it, by setting checkpointInterval. #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -1703,9 +1750,9 @@ print.summary.KSTest <- function(x, ...) { setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, - minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, - probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -1734,7 +1781,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo impurity, as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), as.character(probabilityCol), + as.numeric(subsamplingRate), as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) new("RandomForestClassificationModel", jobj = jobj) } @@ -1745,11 +1792,11 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction" +#' "prediction".
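A minimal sketch matching the trimmed signature above (probabilityCol is gone; maxMemoryInMB and cacheNodeIds moved to the end):

# Hedged sketch: regression forest on the longley data.
df <- createDataFrame(longley)
model <- spark.randomForest(df, Employed ~ ., type = "regression",
                            maxDepth = 5, numTrees = 20, seed = 1)
head(predict(model, df))   # predictions land in a "prediction" column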
#' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method #' @export -#' @note predict(randomForestRegressionModel) since 2.1.0 +#' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { predict_internal(object, newData) @@ -1758,7 +1805,7 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method #' @export -#' @note predict(randomForestClassificationModel) since 2.1.0 +#' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { predict_internal(object, newData) @@ -1766,8 +1813,8 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), # Save the Random Forest Regression or Classification model to the input path. -#' @param object A fitted Random Forest regression model or classification model -#' @param path The directory where the model is saved +#' @param object A fitted Random Forest regression model or classification model. +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1789,8 +1836,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path write_internal(object, path, overwrite) }) -# Get the summary of an RandomForestRegressionModel model -summary.randomForest <- function(model) { +# Create the summary of a tree ensemble model (eg. Random Forest, GBT) +summary.treeEnsemble <- function(model) { jobj <- model@jobj formula <- callJMethod(jobj, "formula") numFeatures <- callJMethod(jobj, "numFeatures") @@ -1807,20 +1854,25 @@ summary.randomForest <- function(model) { jobj = jobj) } -#' @return \code{summary} returns the model's features as lists, depth and number of nodes -#' or number of classes. +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). 
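Because the helper is now shared, the same fields come back for both model families; a sketch where model stands for any spark.randomForest or spark.gbt fit:

# Sketch: fields common to the shared tree-ensemble summary.
s <- summary(model)
s$formula        # formula used for the fit
s$numFeatures    # number of features
s$treeWeights    # one weight per tree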
#' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestRegressionModel" ans }) -# Get the summary of an RandomForestClassificationModel model +# Get the summary of a Random Forest Classification Model #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method @@ -1828,13 +1880,13 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestClassificationModel" ans }) -# Prints the summary of Random Forest Regression Model -print.summary.randomForest <- function(x) { +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { jobj <- x$jobj cat("Formula: ", x$formula) cat("\nNumber of features: ", x$numFeatures) @@ -1848,13 +1900,15 @@ print.summary.randomForest <- function(x) { invisible(x) } +# Prints the summary of Random Forest Regression Model + #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest #' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) } # Prints the summary of Random Forest Classification Model @@ -1863,5 +1917,216 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). 
+#' @param stepSize Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing. +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often the +#' cache should be checkpointed, or disable it, by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT.
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE, +#' which means an exception is thrown if the output path exists.
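A minimal sketch of the save/overwrite/load contract described above, assuming a fitted GBT model in a variable gbtModel (a hypothetical name):

# Illustrative sketch of the write.ml/read.ml round trip; gbtModel is assumed.
path <- tempfile(pattern = "spark-gbt", fileext = ".tmp")
write.ml(gbtModel, path)                    # succeeds while the path is unused
write.ml(gbtModel, path, overwrite = TRUE)  # required once the path already exists
restored <- read.ml(path)
summary(restored)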
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) 
{ + print.summary.treeEnsemble(x) } diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 4dee3245f9b7..8fa21be3076b 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -780,7 +780,7 @@ setMethod("cogroup", #' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitionsRDD(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -850,7 +850,7 @@ setMethod("sortByKey", #' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 6b4a2f2fdc85..d0a12b7ecec6 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -323,6 +323,18 @@ sparkRHive.init <- function(jsc = NULL) { #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. #' +#' When called in an interactive session, this method checks for the Spark installation and, if not +#' found, downloads and caches it automatically. Alternatively, \code{install.spark} can +#' be called manually. +#' +#' A default warehouse is created automatically in the current directory when a managed table is +#' created via \code{sql} statement \code{CREATE TABLE}, for example. To change the location of the +#' warehouse, pass the named parameter \code{spark.sql.warehouse.dir} to the SparkSession. Along with +#' the warehouse, an accompanying metastore may also be automatically created in the current +#' directory when a new SparkSession is initialized with \code{enableHiveSupport} set to +#' \code{TRUE}, which is the default. For more details, refer to Hive configuration at +#' \url{http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables}. +#' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' @@ -373,8 +385,17 @@ sparkR.session <- function( overrideEnvs(sparkConfigMap, paramMap) } + deployMode <- "" + if (exists("spark.submit.deployMode", envir = sparkConfigMap)) { + deployMode <- sparkConfigMap[["spark.submit.deployMode"]] + } + + if (!exists("spark.r.sql.derby.temp.dir", envir = sparkConfigMap)) { + sparkConfigMap[["spark.r.sql.derby.temp.dir"]] <- tempdir() + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { - retHome <- sparkCheckInstall(sparkHome, master) + retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, @@ -402,6 +423,30 @@ sparkR.session <- function( sparkSession } +#' Get the URL of the SparkUI instance for the current active SparkSession +#' +#' Get the URL of the SparkUI instance for the current active SparkSession. +#' +#' @return the SparkUI URL, or NA if it is disabled or not started.
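As a minimal sketch of the warehouse configuration described for sparkR.session above (the directory is illustrative; extra named arguments become Spark properties, as documented):

# Illustrative sketch: relocate the SQL warehouse and skip Hive metastore creation.
sparkR.session(master = "local[*]",
               spark.sql.warehouse.dir = "/tmp/sparkr-warehouse",  # assumed path
               enableHiveSupport = FALSE)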
+#' @rdname sparkR.uiWebUrl +#' @name sparkR.uiWebUrl +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' url <- sparkR.uiWebUrl() +#' } +#' @note sparkR.uiWebUrl since 2.1.1 +sparkR.uiWebUrl <- function() { + sc <- sparkR.callJMethod(getSparkContext(), "sc") + u <- callJMethod(sc, "uiWebUrl") + if (callJMethod(u, "isDefined")) { + callJMethod(u, "get") + } else { + NA + } +} + #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. #' @@ -419,7 +464,7 @@ sparkR.session <- function( #' @method setJobGroup default setJobGroup.default <- function(groupId, description, interruptOnCancel) { sc <- getSparkContext() - callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) + invisible(callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)) } setJobGroup <- function(sc, groupId, description, interruptOnCancel) { @@ -449,7 +494,7 @@ setJobGroup <- function(sc, groupId, description, interruptOnCancel) { #' @method clearJobGroup default clearJobGroup.default <- function() { sc <- getSparkContext() - callJMethod(sc, "clearJobGroup") + invisible(callJMethod(sc, "clearJobGroup")) } clearJobGroup <- function(sc) { @@ -476,7 +521,7 @@ clearJobGroup <- function(sc) { #' @method cancelJobGroup default cancelJobGroup.default <- function(groupId) { sc <- getSparkContext() - callJMethod(sc, "cancelJobGroup", groupId) + invisible(callJMethod(sc, "cancelJobGroup", groupId)) } cancelJobGroup <- function(sc, groupId) { @@ -550,24 +595,25 @@ processSparkPackages <- function(packages) { # # @param sparkHome directory to find Spark package. # @param master the Spark master URL, used to check local or remote mode. +# @param deployMode whether to deploy your driver on the worker nodes (cluster) +# or locally as an external client (client). # @return NULL if no need to update sparkHome, and new sparkHome otherwise. -sparkCheckInstall <- function(sparkHome, master) { +sparkCheckInstall <- function(sparkHome, master, deployMode) { if (!isSparkRShell()) { if (!is.na(file.info(sparkHome)$isdir)) { - msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) - message(msg) + message("Spark package found in SPARK_HOME: ", sparkHome) NULL } else { - if (!nzchar(master) || isMasterLocal(master)) { - msg <- paste0("Spark not found in SPARK_HOME: ", - sparkHome) - message(msg) + if (interactive() || isMasterLocal(master)) { + message("Spark not found in SPARK_HOME: ", sparkHome) packageLocalDir <- install.spark() packageLocalDir - } else { + } else if (isClientMode(master) || deployMode == "client") { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome, "\n", installInstruction("remote")) stop(msg) + } else { + NULL } } } else { diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index abca703617c7..ade0f05c0254 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -29,7 +29,7 @@ PRIMITIVE_TYPES <- as.environment(list( "string" = "character", "binary" = "raw", "boolean" = "logical", - "timestamp" = "POSIXct", + "timestamp" = c("POSIXct", "POSIXt"), "date" = "Date", # following types are not SQL types returned by dtypes(). They are listed here for usage # by checkType() in schema.R. diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 20004549cc03..1f7848f2b413 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -756,12 +756,17 @@ varargsToJProperties <- function(...) 
{ props } -launchScript <- function(script, combinedArgs, capture = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE) { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") - shell(scriptWithArgs, translate = TRUE, wait = capture, intern = capture) # nolint + # on Windows, intern = FALSE seems to mean the output goes to the console (documentation on this is missing) + shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait) # nolint } else { - system2(script, combinedArgs, wait = capture, stdout = capture) + # http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html + # stdout = FALSE means discard output + # stdout = "" means output goes to its console (default) + # Note that the console of this child process might not be the same as the running R process. + system2(script, combinedArgs, stdout = "", wait = wait) } } @@ -777,6 +782,10 @@ isMasterLocal <- function(master) { grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) } +isClientMode <- function(master) { + grepl("([a-z]+)-client$", master, perl = TRUE) +} + isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } @@ -837,7 +846,7 @@ captureJVMException <- function(e, method) { # # @param inputData a list of rows, with each row a list # @return data.frame with raw columns as lists -rbindRaws <- function(inputData){ +rbindRaws <- function(inputData) { row1 <- inputData[[1]] rawcolumns <- ("raw" == sapply(row1, class)) @@ -847,3 +856,19 @@ rbindRaws <- function(inputData){ out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) out } + +# Get basename without extension from URL +basenameSansExtFromUrl <- function(url) { + # strip everything up to and including the last '/' + splits <- unlist(strsplit(url, "^.+/")) + last <- tail(splits, 1) + # this is from file_path_sans_ext + # first, remove any compression extension + filename <- sub("[.](gz|bz2|xz)$", "", last) + # then, strip extension by the last '.' + sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename) +} + +isAtomicLengthOne <- function(x) { + is.atomic(x) && length(x) == 1 +} diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 8813e18a1fa4..1d777ddb286d 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -20,7 +20,8 @@ test_that("sparkJars tag in SparkContext", { if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } - testOutput <- launchScript("ECHO", "a/b/c", capture = TRUE) + + testOutput <- launchScript("ECHO", "a/b/c", wait = TRUE) abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index caca06933952..c84711349111 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -177,6 +177,13 @@ test_that("add and get file to be downloaded with Spark job on every node", { spark.addFile(path) download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) + + # Test spark.getSparkFiles works well on executors. + seq <- seq(from = 1, to = 10, length.out = 5) + f <- function(seq) { spark.getSparkFiles(filename) } + results <- spark.lapply(seq, f) + for (i in 1:5) { expect_equal(basename(results[[i]]), filename) } + unlink(path) # Test add directory recursively.
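Because the launchScript comments above note that this corner of system2 is thinly documented, here is a small runnable sketch of the semantics being relied on (Unix-alikes; on Windows, echo is a shell builtin, hence the shell() branch):

# stdout = "" streams the child's output to a console; stdout = TRUE captures it;
# stdout = FALSE discards it. wait controls whether R blocks on the child process.
system2("echo", "a/b/c", stdout = "", wait = TRUE)   # prints a/b/c to the console
captured <- system2("echo", "a/b/c", stdout = TRUE)  # wait is implied when capturing
identical(captured, "a/b/c")                         # TRUE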
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index db98d0e45547..8fe3a87f6faa 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -74,6 +74,14 @@ test_that("spark.glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) @@ -159,6 +167,15 @@ test_that("spark.glm summary", { df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . - 1)) + coefs <- unlist(stats$coefficients) + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) }) test_that("spark.glm save/load", { @@ -341,6 +358,8 @@ test_that("spark.kmeans", { # Test summary works on KMeans summary.model <- summary(model) cluster <- summary.model$cluster + k <- summary.model$k + expect_equal(k, 2) expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) # Test model save/load @@ -356,17 +375,45 @@ test_that("spark.kmeans", { expect_true(summary2$is.loaded) unlink(modelPath) + + # Test Kmeans on dataset that is sensitive to seed value + col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + cols <- as.data.frame(cbind(col1, col2, col3)) + df <- createDataFrame(cols) + + model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 1, tol = 1E-5) + model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 22222, tol = 1E-5) + + summary.model1 <- summary(model1) + summary.model2 <- summary(model2) + cluster1 <- summary.model1$cluster + cluster2 <- summary.model2$cluster + clusterSize1 <- summary.model1$clusterSize + clusterSize2 <- summary.model2$clusterSize + + # The predicted clusters are different + expect_equal(sort(collect(distinct(select(cluster1, "prediction")))$prediction), + c(0, 1, 2, 3)) + expect_equal(sort(collect(distinct(select(cluster2, "prediction")))$prediction), + c(0, 1, 2)) + expect_equal(clusterSize1, 4) + expect_equal(clusterSize2, 3) }) test_that("spark.mlp", { df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") - model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) + model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), + solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) # Test summary method summary <- summary(model) - expect_equal(summary$labelCount, 3) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 5, 4, 3)) expect_equal(length(summary$weights), 64) expect_equal(head(summary$weights, 5), list(-0.878743, 
0.2154151, -1.16304, -0.6583214, 1.009825), @@ -375,7 +422,7 @@ test_that("spark.mlp", { # Test predict method mlpTestDF <- df mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1)) + expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") @@ -385,46 +432,68 @@ test_that("spark.mlp", { model2 <- read.ml(modelPath) summary2 <- summary(model2) - expect_equal(summary2$labelCount, 3) + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) expect_equal(summary2$layers, c(4, 5, 4, 3)) expect_equal(length(summary2$weights), 64) unlink(modelPath) # Test default parameter - model <- spark.mlp(df, layers = c(4, 5, 4, 3)) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test illegal parameter - expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = NULL), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c()), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c(3)), + "layers must be a integer vector with length > 1.") # Test random seed # default seed - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # seed equals 10 - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 
5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + + # Test formula works well + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + layers = c(4, 3)) + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 3)) + expect_equal(length(summary$weights), 15) + expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, + -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { @@ -603,58 +672,141 @@ test_that("spark.isotonicRegression", { }) test_that("spark.logit", { - # test binary logistic regression - label <- c(1.0, 1.0, 1.0, 0.0, 0.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - binary_data <- as.data.frame(cbind(label, feature)) - binary_df <- createDataFrame(binary_data) - - blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) - blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) - expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0)) - blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) - blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) - expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1)) - - # test summary of binary logistic regression - blr_summary <- summary(blr_model) - blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) - expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487), - tolerance = 1e-4) - expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000), - tolerance = 1e-4) - blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) - expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000), - tolerance = 1e-4) - blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) - expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000), - tolerance = 1e-4) + # R code to reproduce the result. 
+ # nolint start + #' library(glmnet) + #' iris.x = as.matrix(iris[, 1:4]) + #' iris.y = as.factor(as.character(iris[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $setosa + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.0981324 + # Sepal.Length -0.2909860 + # Sepal.Width 0.5510907 + # Petal.Length -0.1915217 + # Petal.Width -0.4211946 + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.520061e+00 + # Sepal.Length 2.524501e-02 + # Sepal.Width -5.310313e-01 + # Petal.Length 3.656543e-02 + # Petal.Width -3.144464e-05 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -2.61819385 + # Sepal.Length 0.26574097 + # Sepal.Width -0.02005932 + # Petal.Length 0.15495629 + # Petal.Width 0.42122607 + # nolint end - # test model save and read - modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") - write.ml(blr_model, modelPath) - expect_error(write.ml(blr_model, modelPath)) - write.ml(blr_model, modelPath, overwrite = TRUE) - blr_model2 <- read.ml(modelPath) - blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) - expect_equal(blr_predict$prediction, blr_predict2$prediction) - expect_error(summary(blr_model2)) + # Test multinomial logistic regression against three classes + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.logit(df, Species ~ ., regParam = 0.5) + summary <- summary(model) + versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) + virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) + setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + setosaCoefs <- unlist(summary$coefficients[, "setosa"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + expect_true(all(abs(setosaCoefsR - setosaCoefs) < 0.1)) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) unlink(modelPath) - # test multinomial logistic regression - label <- c(0.0, 1.0, 2.0, 0.0, 0.0) - feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) - feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) - feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) - feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) - data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) - df <- createDataFrame(data) + # R code to reproduce the result.
+ # nolint start + #' library(glmnet) + #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] + #' iris.x = as.matrix(iris2[, 1:4]) + #' iris.y = as.factor(as.character(iris2[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 3.93844796 + # Sepal.Length -0.13538675 + # Sepal.Width -0.02386443 + # Petal.Length -0.35076451 + # Petal.Width -0.77971954 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -3.93844796 + # Sepal.Length 0.13538675 + # Sepal.Width 0.02386443 + # Petal.Length 0.35076451 + # Petal.Width 0.77971954 + # + #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # (Intercept) -6.0824412 + # Sepal.Length 0.2458260 + # Sepal.Width 0.1642093 + # Petal.Length 0.4759487 + # Petal.Width 1.0383948 + # + # nolint end + + # Test multinomial logistic regression against two classes + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") + summary <- summary(model) + versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) + virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test binomial logistic regression against two classes + model <- spark.logit(training, Species ~ ., regParam = 0.5) + summary <- summary(model) + coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) + coefs <- unlist(summary$coefficients[, "Estimate"]) + expect_true(all(abs(coefsR - coefs) < 0.1)) + + # Test prediction with string label + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) - model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) - predict1 <- collect(select(predict(model, df), "prediction")) - expect_equal(predict1$prediction, c(0, 0, 0, 0, 0)) - # Summary of multinomial logistic regression is not implemented yet - expect_error(summary(model)) + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) + df <- createDataFrame(data) + model <- spark.logit(df, label ~ feature) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) }) test_that("spark.gaussianMixture", { @@ -735,7 +887,7 @@ test_that("spark.lda with libsvm", { weights <- stats$topicTopTermsWeights vocabulary <- stats$vocabulary - expect_false(isDistributed) + expect_true(isDistributed) expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) expect_equal(vocabSize, 11) @@ -749,7 +901,7 @@ test_that("spark.lda with libsvm", { model2 <- read.ml(modelPath) stats2 <-
summary(model2) - expect_false(stats2$isDistributed) + expect_true(stats2$isDistributed) expect_equal(logLikelihood, stats2$logLikelihood) expect_equal(logPerplexity, stats2$logPerplexity) expect_equal(vocabSize, stats2$vocabSize) @@ -811,10 +963,10 @@ test_that("spark.posterior and spark.perplexity", { test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), - list(2, 1, 1.0), list(2, 2, 5.0)) + list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(data, c("user", "item", "score")) model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", - rank = 10, maxIter = 5, seed = 0, reg = 0.1) + rank = 10, maxIter = 5, seed = 0, regParam = 0.1) stats <- summary(model) expect_equal(stats$rank, 10) test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) @@ -869,9 +1021,16 @@ test_that("spark.kstest", { expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + # Test print.summary.KSTest + printStats <- capture.output(print.summary.KSTest(stats)) + expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") + expect_match(printStats[5], + "Low presumption against null hypothesis: Sample follows theoretical distribution. ") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -891,10 +1050,11 @@ test_that("spark.randomForest Regression", { model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 20, seed = 123) predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, - 63.736, 64.296, 64.868, 64.300, - 66.709, 67.697, 67.966, 67.252, - 68.866, 69.593, 69.195, 69.658), + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) @@ -913,9 +1073,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -925,6 +1084,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -937,6 +1100,106 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + 
iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. 
+ iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index a3d66c245a7d..787ef51c501c 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -29,8 +29,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(getNumPartitions(rdd), 2) - expect_equal(getNumPartitions(intRdd), 2) + expect_equal(getNumPartitionsRDD(rdd), 2) + expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { @@ -305,18 +305,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartitionRDD(rdd, 2) - expect_equal(getNumPartitions(r1), 2L) + expect_equal(getNumPartitionsRDD(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartitionRDD(rdd, 6) - expect_equal(getNumPartitions(r2), 6L) + expect_equal(getNumPartitionsRDD(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce - r3 <- coalesce(rdd, 1) - expect_equal(getNumPartitions(r3), 1L) + r3 <- coalesceRDD(rdd, 1) + expect_equal(getNumPartitionsRDD(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) @@ -381,8 +381,8 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 3), list("c", 1), - list("d", 4), list("e", 2)) + expected <- list(list("a", 0), list("b", 1), list("c", 4), + list("d", 2), list("e", 5)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R new file mode 
100644 index 000000000000..f73fc6baecce --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in sparkR.R") + +test_that("sparkCheckInstall", { + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- paste0(tempdir(), "/", "sparkHome") + dir.create(sparkHome) + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + unlink(sparkHome, recursive = TRUE) + + # "yarn-cluster, mesos-cluster" mode, SPARK_HOME was not set, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- "" + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + + # "yarn-client, mesos-client" mode, SPARK_HOME was not set + sparkHome <- "" + master <- "yarn-client" + deployMode <- "" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) + sparkHome <- "" + master <- "" + deployMode <- "client" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) +}) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524f..628512440d6e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,6 +60,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR +filesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) @@ -196,6 +197,26 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(as.list(collect(where(df, df$name == "John"))), list(name = "John", age = 19L, height = 176.5)) + expect_equal(getNumPartitions(df), 1) + + df <- as.DataFrame(cars, numPartitions = 2) + expect_equal(getNumPartitions(df), 2) + df <- createDataFrame(cars, numPartitions = 3) + expect_equal(getNumPartitions(df), 3) + # validate limit by num of rows + df <- createDataFrame(cars, numPartitions = 60) + expect_equal(getNumPartitions(df), 50) + # validate when 1 < (length(coll) / numSlices) << length(coll) + df <- createDataFrame(cars, numPartitions = 20) + expect_equal(getNumPartitions(df), 20) + + df <- as.DataFrame(data.frame(0)) + expect_is(df, "SparkDataFrame") + df <- createDataFrame(list(list(1))) + expect_is(df, "SparkDataFrame") + df <- as.DataFrame(data.frame(0), numPartitions = 2) + # no data to partition, goes to 1 + expect_equal(getNumPartitions(df), 1) setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") @@ -205,6 +226,7 
@@ test_that("create DataFrame from RDD", { c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) + sql("DROP TABLE people") unsetHiveContext() }) @@ -212,7 +234,8 @@ test_that("createDataFrame uses files for large objects", { # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") - df <- suppressWarnings(createDataFrame(iris)) + df <- suppressWarnings(createDataFrame(iris, numPartitions = 3)) + expect_equal(getNumPartitions(df), 3) # Resetting the conf back to default value callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) @@ -576,7 +599,7 @@ test_that("test tableNames and tables", { tables <- tables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) - dropTempView("table2") + expect_true(dropTempView("table2")) tables <- tables() expect_equal(count(tables), 0) @@ -589,7 +612,7 @@ test_that( newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(df, "dfView") sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) @@ -600,7 +623,7 @@ test_that( expect_equal(ncol(sqlCast), 1) expect_equal(out[1], " x") expect_equal(out[2], "1 2") - dropTempView("dfView") + expect_true(dropTempView("dfView")) }) test_that("test cache, uncache and clearCache", { @@ -609,7 +632,7 @@ test_that("test cache, uncache and clearCache", { cacheTable("table1") uncacheTable("table1") clearCache() - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("insertInto() on a registered table", { @@ -630,13 +653,13 @@ test_that("insertInto() on a registered table", { insertInto(dfParquet2, "table1") expect_equal(count(sql("select * from table1")), 5) expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) expect_equal(count(sql("select * from table1")), 2) expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - dropTempView("table1") + expect_true(dropTempView("table1")) unlink(jsonPath2) unlink(parquetPath2) @@ -650,7 +673,7 @@ test_that("tableToDF() returns a new DataFrame", { expect_equal(count(tabledf), 3) tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("toRDD() returns an RRDD", { @@ -703,7 +726,7 @@ test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) - saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + saveAsObjectFile(coalesceRDD(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") @@ -985,6 +1008,18 @@ test_that("select operators", { expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") + expect_warning(df[[1:2]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[1:2]]), "Column") + expect_warning(df[[c("name", "age")]], + "Subset index has length > 1. 
Only the first index is used.") + expect_is(suppressWarnings(df[[c("name", "age")]]), "Column") + + expect_warning(df[[1:2]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_warning(df[[c("name", "age")]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_is(df[, 1, drop = F], "SparkDataFrame") expect_equal(columns(df[, 1, drop = F]), c("name")) expect_equal(columns(df[, "age", drop = F]), c("age")) @@ -999,6 +1034,37 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + df$age2 <- df[["age"]] * 3 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 3)), 2) + + df$age2 <- 21 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 21)), 3) + + df$age2 <- c(22) + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 22)), 3) + + expect_error(df$age3 <- c(22, NA), + "value must be a Column, literal value as atomic in length of 1, or NULL") + + df[["age2"]] <- 23 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 23)), 3) + + df[[3]] <- 24 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 24)), 3) + + df[[3]] <- df$age + expect_equal(count(where(df, df$age2 == df$age)), 2) + + df[["age2"]] <- df[["name"]] + expect_equal(count(where(df, df$age2 == df$name)), 3) + + expect_error(df[["age3"]] <- c(22, 23), + "value must be a Column, literal value as atomic in length of 1, or NULL") # Test parameter drop expect_equal(class(df[, 1]) == "SparkDataFrame", T) @@ -1175,7 +1241,7 @@ test_that("column functions", { c16 <- is.nan(c) + isnan(c) + isNaN(c) c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") c18 <- covar_pop(c, c1) + covar_pop("c", "c1") - c19 <- spark_partition_id() + c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1222,16 +1288,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() @@ -1244,9 +1310,9 @@ test_that("column functions", { # Test first(), last() df <- read.json(jsonPath) - expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age)))[[1]], NA_real_) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) - expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, 
first("age")))[[1]], NA_real_) expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) expect_equal(collect(select(df, last(df$age)))[[1]], 19) expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) @@ -1776,6 +1842,13 @@ test_that("withColumn() and withColumnRenamed()", { expect_equal(length(columns(newDF)), 2) expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + newDF <- withColumn(df, "age", 18) + expect_equal(length(columns(newDF)), 2) + expect_equal(first(newDF)$age, 18) + + expect_error(withColumn(df, "age", list("a")), + "Literal value must be atomic in length of 1") + newDF2 <- withColumnRenamed(df, "age", "newerAge") expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") @@ -2421,15 +2494,18 @@ test_that("repartition by columns on DataFrame", { ("Please, specify the number of partitions and/or a column\\(s\\)", retError), TRUE) # repartition by column and number of partitions - actual <- repartition(df, 3L, col = df$"a") + actual <- repartition(df, 3, col = df$"a") - # since we cannot access the number of partitions from dataframe, checking - # that at least the dimensions are identical + # Checking that at least the dimensions are identical expect_identical(dim(df), dim(actual)) + expect_equal(getNumPartitions(actual), 3L) # repartition by number of partitions actual <- repartition(df, 13L) expect_identical(dim(df), dim(actual)) + expect_equal(getNumPartitions(actual), 13L) + + expect_equal(getNumPartitions(coalesce(actual, 1L)), 1L) # a test case with a column and dapply schema <- structType(structField("a", "integer"), structField("avg", "double")) @@ -2445,6 +2521,25 @@ test_that("repartition by columns on DataFrame", { expect_equal(nrow(df1), 2) }) +test_that("coalesce, repartition, numPartitions", { + df <- as.DataFrame(cars, numPartitions = 5) + expect_equal(getNumPartitions(df), 5) + expect_equal(getNumPartitions(coalesce(df, 3)), 3) + expect_equal(getNumPartitions(coalesce(df, 6)), 5) + + df1 <- coalesce(df, 3) + expect_equal(getNumPartitions(df1), 3) + expect_equal(getNumPartitions(coalesce(df1, 6)), 5) + expect_equal(getNumPartitions(coalesce(df1, 4)), 4) + expect_equal(getNumPartitions(coalesce(df1, 2)), 2) + + df2 <- repartition(df1, 10) + expect_equal(getNumPartitions(df2), 10) + expect_equal(getNumPartitions(coalesce(df2, 13)), 5) + expect_equal(getNumPartitions(coalesce(df2, 7)), 5) + expect_equal(getNumPartitions(coalesce(df2, 3)), 3) +}) + test_that("gapply() and gapplyCollect() on a DataFrame", { df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), @@ -2612,7 +2707,7 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) -test_that("Setting and getting config on SparkSession", { +test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiWebUrl()", { # first, set it to a random but known value conf <- callJMethod(sparkSession, "conf") property <- paste0("spark.testing.", as.character(runif(1))) @@ -2636,6 +2731,9 @@ test_that("Setting and getting config on SparkSession", { expect_equal(appNameValue, "sparkSession test") expect_equal(testValue, value) expect_error(sparkR.conf("completely.dummy"), "Config 'completely.dummy' is not set") + + url <- sparkR.uiWebUrl() + expect_equal(substr(url, 1, 7), "http://") }) test_that("enableHiveSupport on SparkSession", { @@ -2659,7 +2757,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # It 
makes sure that we can omit path argument in write.df API and then it calls # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), - "Error in save : illegal argument - 'path' is not specified") + "Error in save : illegal argument - Expected exactly one path to be specified") expect_error(write.json(df, jsonPath), "Error in json : analysis error - path file:.*already exists") expect_error(write.text(df, jsonPath), @@ -2667,7 +2765,7 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume expect_error(write.orc(df, jsonPath), "Error in orc : analysis error - path file:.*already exists") expect_error(write.parquet(df, jsonPath), - "Error in parquet : analysis error - path file:.*already exists") + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2684,7 +2782,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") @@ -2704,6 +2802,79 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { + ldf <- data.frame(col1 = c(0, 1, 2), + col2 = c(as.POSIXct("2017-01-01 00:00:01"), + NA, + as.POSIXct("2017-01-01 12:00:01")), + col3 = c(as.POSIXlt("2016-01-01 00:59:59"), + NA, + as.POSIXlt("2016-01-01 12:01:01"))) + sdf1 <- createDataFrame(ldf) + ldf1 <- collect(sdf1) + expect_equal(dtypes(sdf1), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf1$col1), "numeric") + expect_equal(class(ldf1$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf1$col3), c("POSIXct", "POSIXt")) + + # Columns with NAs at the top + sdf2 <- filter(sdf1, "col1 > 1") + ldf2 <- collect(sdf2) + expect_equal(dtypes(sdf2), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf2$col1), "numeric") + expect_equal(class(ldf2$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf2$col3), c("POSIXct", "POSIXt")) + + # Columns with only NAs, the type will also be cast to PRIMITIVE_TYPE + sdf3 <- filter(sdf1, "col1 == 0") + ldf3 <- collect(sdf3) + expect_equal(dtypes(sdf3), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf3$col1), "numeric") + expect_equal(class(ldf3$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) +}) + +compare_list <- function(list1, list2) { + # get testthat to show the diff by first making the 2 lists equal in length + expect_equal(length(list1), length(list2)) + l <- max(length(list1), length(list2)) + length(list1) <- l + length(list2) <- l + expect_equal(sort(list1, na.last = TRUE), sort(list2, na.last = TRUE)) +} + +# This should always be the **very last test** in this test file. 
+test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() + + # Check that it is not creating any extra file. + # Does not check the tempdir which would be cleaned up after. + filesAfter <- list.files(path = sparkRDir, all.files = TRUE) + + expect_true(length(sparkRFilesBefore) > 0) + # first, ensure derby.log is not there + expect_false("derby.log" %in% filesAfter) + # second, ensure only spark-warehouse is created when calling SparkSession, enableHiveSupport = F + # note: currently all other test files have enableHiveSupport = F, so we capture the list of files + # before creating a SparkSession with enableHiveSupport = T at the top of this test file + # (filesBefore). The test here is to compare that (filesBefore) against the list of files before + # any test is run in run-all.R (sparkRFilesBefore). + # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the + # same as before any test is run. + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T + # note: as the note above, after running all tests in this file while enableHiveSupport = T, we + # check the list of files again. This time we allow both whitelisted dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 607c407f04f9..c87524842876 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -228,4 +228,15 @@ test_that("varargsToStrEnv", { expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) +test_that("basenameSansExtFromUrl", { + x <- paste0("http://people.apache.org/~pwendell/spark-nightly/spark-branch-2.1-bin/spark-2.1.1-", + "SNAPSHOT-2016_12_09_11_08-eb2d9bf-bin/spark-2.1.1-SNAPSHOT-bin-hadoop2.7.tgz") + y <- paste0("http://people.apache.org/~pwendell/spark-releases/spark-2.1.0-rc2-bin/spark-2.1.0-", + "bin-hadoop2.4-without-hive.tgz") + expect_equal(basenameSansExtFromUrl(x), "spark-2.1.1-SNAPSHOT-bin-hadoop2.7") + expect_equal(basenameSansExtFromUrl(y), "spark-2.1.0-bin-hadoop2.4-without-hive") + z <- "http://people.apache.org/~pwendell/spark-releases/spark-2.1.0--hive.tar.gz" + expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") +}) + sparkR.session.stop() diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 1d04656ac259..9b75d9556692 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,4 +21,13 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +install.spark() + +# Setup global test environment +sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") +invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bdd..d16526d306d6 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,14 +1,32 @@ --- title: 
"SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- + + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). @@ -26,6 +44,9 @@ library(SparkR) We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). +```{r, include=FALSE} +install.spark() +``` ```{r, message=FALSE, results="hide"} sparkR.session() ``` @@ -93,13 +114,13 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() ``` -If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the existing Spark installation is. ```{r, eval=FALSE} sparkR.session(sparkHome = "/HOME/spark") @@ -446,25 +467,43 @@ head(teenagers) SparkR supports the following machine learning models and algorithms. 
-* Generalized Linear Model (GLM) +#### Classification -* Naive Bayes Model +* Logistic Regression -* $k$-means Clustering +* Multilayer Perceptron (MLP) + +* Naive Bayes + +#### Regression * Accelerated Failure Time (AFT) Survival Model +* Generalized Linear Model (GLM) + +* Isotonic Regression + +#### Tree - Classification and Regression + +* Gradient-Boosted Trees (GBT) + +* Random Forest + +#### Clustering + * Gaussian Mixture Model (GMM) +* $k$-means Clustering + * Latent Dirichlet Allocation (LDA) -* Multilayer Perceptron Model +#### Collaborative Filtering -* Collaborative Filtering with Alternating Least Squares (ALS) +* Alternating Least Squares (ALS) -* Isotonic Regression Model +#### Statistics -More will be added in the future. +* Kolmogorov-Smirnov Test ### R Formula @@ -489,9 +528,115 @@ count(carsDF_test) head(carsDF_test) ``` - ### Models and Algorithms +#### Logistic Regression + +[Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) is a widely used model when the response is categorical. It can be seen as a special case of the [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). +We provide `spark.logit` on top of `spark.glm` to support logistic regression with advanced hyper-parameters. +It supports both binary and multiclass classification with elastic-net regularization and feature standardization, similar to `glmnet`. + +We use a simple example to demonstrate `spark.logit` usage. In general, there are three steps to using `spark.logit`: +1) create a DataFrame from a proper data source; 2) fit a logistic regression model using `spark.logit` with a proper parameter setting; +and 3) obtain the coefficient matrix of the fitted model using `summary`, and use the model for prediction with `predict`. + +Binomial logistic regression +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Create a DataFrame containing two classes +training <- df[df$Species %in% c("versicolor", "virginica"), ] +model <- spark.logit(training, Species ~ ., regParam = 0.00042) +summary(model) +``` + +Predict values on training data +```{r} +fitted <- predict(model, training) +``` + +Multinomial logistic regression against three classes +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Note that in this case, Spark infers multinomial logistic regression, so family = "multinomial" is optional. +model <- spark.logit(df, Species ~ ., regParam = 0.056) +summary(model) +``` + +#### Multilayer Perceptron + +The multilayer perceptron classifier (MLPC) is a classifier based on a [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by taking a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: +$$ +y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). +$$ + +Nodes in intermediate layers use the sigmoid (logistic) function: +$$ +f(z_i) = \frac{1}{1+e^{-z_i}}. +$$ + +Nodes in the output layer use the softmax function: +$$ +f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. +$$ + +The number of nodes $N$ in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model.
We use the logistic loss function for optimization and L-BFGS as the optimization routine. + +`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other named `"features"`. The `"features"` column should be in libSVM format. + +We use the iris data set to show how to use `spark.mlp` in classification. +```{r, warning=FALSE} +df <- createDataFrame(iris) +# fit a Multilayer Perceptron Classification Model +model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +``` + +To avoid lengthy display, we only present partial results of the model summary. You can check the full result in your SparkR shell. +```{r, include=FALSE} +ops <- options() +options(max.print=5) +``` +```{r} +# check the summary of the fitted model +summary(model) +``` +```{r, include=FALSE} +options(ops) +``` +```{r} +# make predictions using the fitted model +predictions <- predict(model, df) +head(select(predictions, predictions$prediction)) +``` + +#### Naive Bayes + +The naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. + +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### Accelerated Failure Time Survival Model + +Survival analysis studies the expected duration of time until an event happens, and often its relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data, including non-negative survival time and censoring. + +The Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. +```{r, warning=FALSE} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + #### Generalized Linear Model The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. @@ -525,46 +670,78 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` -#### Naive Bayes Model -Naive Bayes model assumes independence among the features.
`spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. +`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize +$$ +\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. +$$ + +There are a few more arguments that may be useful. + +* `weightCol`: a character string specifying the weight column. + +* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). + +* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0); no effect otherwise. + +We use an artificial example to show its use. ```{r} -titanic <- as.data.frame(Titanic) -titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) -naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) -summary(naiveBayesModel) -naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) -head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +y <- c(3.0, 6.0, 8.0, 5.0, 7.0) +x <- c(1.0, 2.0, 3.5, 3.0, 4.0) +w <- rep(1.0, 5) +data <- data.frame(y = y, x = x, w = w) +df <- createDataFrame(data) +isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") +isoregFitted <- predict(isoregModel, df) +head(select(isoregFitted, "x", "y", "prediction")) ``` -#### k-Means Clustering +In the prediction stage, based on the fitted monotone piecewise function, the rules are: -`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. +* If the prediction input exactly matches a training feature, then the associated prediction is returned. If there are multiple predictions with the same feature, then one of them is returned; which one is undefined. + +* If the prediction input is lower or higher than all training features, then the prediction with the lowest or highest feature, respectively, is returned. If there are multiple predictions with the same feature, then the lowest or highest, respectively, is returned. + +* If the prediction input falls between two training features, then the prediction is treated as a piecewise linear function, and an interpolated value is calculated from the predictions of the two closest features. If there are multiple values with the same feature, then the same rules as in the previous point are used. + +For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, so the predicted value is a linear interpolation between the predicted values at $3.0$ and $3.5$.
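As a quick cross-check of this rule (an editorial aside, not part of the patch): base R's `isoreg` implements the same pool-adjacent-violators algorithm, so for unit weights it should reproduce the fitted values that `spark.isoreg` computes on this artificial data, and `stats::approx` can then carry out the linear interpolation by hand. The chunk below is a minimal sketch under that assumption:

```{r, eval=FALSE}
# Local cross-check of the interpolation rule, assuming base::isoreg agrees
# with spark.isoreg on unit-weight data (both use pool-adjacent-violators).
fit <- isoreg(x, y)               # x and y as defined in the chunk above
xs <- sort(x)                     # fit$yf holds fitted values for ordered x
approx(xs, fit$yf, xout = 3.2)$y  # interpolates between the fits at 3.0 and 3.5
```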
```{r} -kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) -summary(kmeansModel) -kmeansPredictions <- predict(kmeansModel, carsDF) -head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) +head(predict(isoregModel, newDF)) ``` -#### AFT Survival Model -Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. +#### Gradient-Boosted Trees + +`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +Similar to the random forest example below, we use the `longley` dataset to train a gradient-boosted tree and make predictions: -Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. ```{r, warning=FALSE} -library(survival) -ovarianDF <- createDataFrame(ovarian) -aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) -summary(aftModel) -aftPredictions <- predict(aftModel, ovarianDF) -head(aftPredictions) +df <- createDataFrame(longley) +gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +summary(gbtModel) +predictions <- predict(gbtModel, df) ``` -#### Gaussian Mixture Model +#### Random Forest + +`spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +In the following example, we use the `longley` dataset to train a random forest and make predictions: -(Coming in 2.1.0) +```{r, warning=FALSE} +df <- createDataFrame(longley) +rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +summary(rfModel) +predictions <- predict(rfModel, df) +``` + +#### Gaussian Mixture Model `spark.gaussianMixture` fits a multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. @@ -580,10 +757,18 @@ gmmFitted <- predict(gmmModel, df) head(select(gmmFitted, "V1", "V2", "prediction")) ``` +#### k-Means Clustering -#### Latent Dirichlet Allocation +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`.
As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. + +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` -(Coming in 2.1.0) +#### Latent Dirichlet Allocation `spark.lda` fits a [Latent Dirichlet Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model on a `SparkDataFrame`. It is often used in topic modeling in which topics are inferred from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: @@ -599,22 +784,6 @@ To use LDA, we need to specify a `features` column in `data` where each entry re * libSVM: Each entry is a collection of words and will be processed directly. -There are several parameters LDA takes for fitting the model. - -* `k`: number of topics (default 10). - -* `maxIter`: maximum iterations (default 20). - -* `optimizer`: optimizer to train an LDA model, "online" (default) uses [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf). "em" uses [expectation-maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm). - -* `subsamplingRate`: For `optimizer = "online"`. Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1] (default 0.05). - -* `topicConcentration`: concentration parameter (commonly named beta or eta) for the prior placed on topic distributions over terms, default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective topicConcentration. Only 1-size numeric is accepted. - -* `docConcentration`: concentration parameter (commonly named alpha) for the prior placed on documents distributions over topics (theta), default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective docConcentration. Only 1-size or k-size numeric is accepted. - -* `maxVocabSize`: maximum vocabulary size, default 1 << 18. - Two more functions are provided for the fitted model. * `spark.posterior` returns a `SparkDataFrame` containing a column of posterior probabilities vectors named "topicDistribution". @@ -653,47 +822,7 @@ perplexity <- spark.perplexity(model, corpusDF) perplexity ``` - -#### Multilayer Perceptron - -(Coming in 2.1.0) - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: -$$ -y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). -$$ - -Nodes in intermediate layers use sigmoid (logistic) function: -$$ -f(z_i) = \frac{1}{1+e^{-z_i}}. -$$ - -Nodes in the output layer use softmax function: -$$ -f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. -$$ - -The number of nodes $N$ in the output layer corresponds to the number of classes. - -MLPC employs backpropagation for learning the model. 
We use the logistic loss function for optimization and L-BFGS as an optimization routine. - -`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. According to the description above, there are several additional parameters that can be set: - -* `layers`: integer vector containing the number of nodes for each layer. - -* `solver`: solver parameter, supported options: `"gd"` (minibatch gradient descent) or `"l-bfgs"`. - -* `maxIter`: maximum iteration number. - -* `tol`: convergence tolerance of iterations. - -* `stepSize`: step size for `"gd"`. - -* `seed`: seed parameter for weights initialization. - -#### Collaborative Filtering - -(Coming in 2.1.0) +#### Alternating Least Squares `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). @@ -722,53 +851,28 @@ predicted <- predict(model, df) head(predicted) ``` -#### Isotonic Regression Model - -(Coming in 2.1.0) - -`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize -$$ -\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. -$$ - -There are a few more arguments that may be useful. - -* `weightCol`: a character string specifying the weight column. - -* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). - -* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. - -We use an artificial example to show the use. - -```{r} -y <- c(3.0, 6.0, 8.0, 5.0, 7.0) -x <- c(1.0, 2.0, 3.5, 3.0, 4.0) -w <- rep(1.0, 5) -data <- data.frame(y = y, x = x, w = w) -df <- createDataFrame(data) -isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") -isoregFitted <- predict(isoregModel, df) -head(select(isoregFitted, "x", "y", "prediction")) -``` - -In the prediction stage, based on the fitted monotone piecewise function, the rules are: +#### Kolmogorov-Smirnov Test -* If the prediction input exactly matches a training feature then associated prediction is returned. In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. +`spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). +Given a `SparkDataFrame`, the test compares continuous data in a given column `testCol` with the theoretical distribution +specified by parameter `nullHypothesis`. +Users can call `summary` to get a summary of the test results. -* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. +In the following example, we test whether the `longley` dataset's `Armed_Forces` column +follows a normal distribution. 
We set the parameters of the normal distribution using +the mean and standard deviation of the sample. -* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. - -For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. +```{r, warning=FALSE} +df <- createDataFrame(longley) +afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) +afMean <- afStats[1] +afStd <- afStats[2] -```{r} -newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) -head(predict(isoregModel, newDF)) +test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +testSummary <- summary(test) +testSummary ``` -#### What's More? -We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in the next version 2.1.0. ### Model Persistence The following example shows how to save/load an ML model in SparkR. @@ -822,9 +926,9 @@ The main method calls of actual computation happen in the Spark JVM of the driver. Two kinds of RPCs are supported in the SparkR JVM backend: method invocation and creating new objects. Method invocation can be done in two ways. -* `sparkR.invokeJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. +* `sparkR.callJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. -* `sparkR.invokeJStatic` takes a class name for static method and a list of arguments to be passed on to the method. +* `sparkR.callJStatic` takes a class name for a static method and a list of arguments to be passed on to the method. The arguments are serialized using our custom wire format, which is then deserialized on the JVM side. We then use Java reflection to invoke the appropriate method. A minimal sketch of these two calls appears below. diff --git a/README.md b/README.md index dd7d0e22495b..d861e9fee705 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,7 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html) -and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). +guide, on the [project web page](http://spark.apache.org/documentation.html). This README file only contains basic setup instructions. ## Building Spark @@ -29,8 +28,9 @@ To build Spark and its example programs, run: You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) -and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). + +For general development tips, including info on developing Spark using an IDE, see +[the Useful Developer Tools page](http://spark.apache.org/developer-tools.html).
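Here is that sketch of the two SparkR backend RPC kinds. `sparkR.callJMethod` and `sparkR.callJStatic` are the names from the patch; the object-creation helper `sparkR.newJObject` and the `java.lang` example targets are our assumptions for illustration, and everything requires an active `sparkR.session()`:

```{r, eval=FALSE}
# Minimal sketch of the two RPC kinds (assumes an active sparkR.session()).
# The java.lang targets are arbitrary illustrations, not SparkR internals.

# 1. Method invocation on a static method, addressed by class name:
sparkR.callJStatic("java.lang.Math", "min", 10L, 1L)  # should return 1

# 2. Creating a new object on the JVM, then invoking a method on it:
jobj <- sparkR.newJObject("java.lang.Integer", 1L)
sparkR.callJMethod(jobj, "toString")                  # should return "1"
```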
## Interactive Scala Shell @@ -80,7 +80,7 @@ can be run using: ./dev/run-tests Please see the guidance on how to -[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). +[run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). ## A Note About Hadoop Versions @@ -98,7 +98,7 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Contributing +## Contributing -Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -wiki for information on how to get started contributing to the project. +Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) +for information on how to get started contributing to the project. diff --git a/assembly/pom.xml b/assembly/pom.xml index ec243eaebaea..6e092ef8928b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../pom.xml diff --git a/bin/beeline b/bin/beeline index 1627626941a7..058534699e44 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 000000000000..fa78407d4175 --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using the "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short circuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py, we are not pip installed, so we don't + # need to search the different Python directories for a Spark installation. + # Note, though, that if the user has pip-installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory, we want to use that version of PySpark rather than the + # pip-installed version of PySpark.
+ export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3..8a2f709960a2 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a4432..98387c2ec5b8 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c412026..4ba5399311d3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f..65d3b9612909 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 869c0b202f7f..f6157f42843e 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -50,7 +50,16 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. 
set RUNNER=java -if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java +if not "x%JAVA_HOME%"=="x" ( + set RUNNER=%JAVA_HOME%\bin\java +) else ( + where /q "%RUNNER%" + if ERRORLEVEL 1 ( + echo Java not found and JAVA_HOME environment variable is not set. + echo Install Java and set JAVA_HOME to point to the Java installation directory. + exit /b 1 + ) +) rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880e..421f36cac3d4 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51d..b08b944ebd31 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b..4e9d3614e637 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173..29ab10df8ab6 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c9..77a4b64e8da9 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -87,6 +87,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.mockito mockito-core diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 5b69e2bb0354..37ba543380f0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -62,8 +62,20 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; - private final MessageEncoder encoder; - private final MessageDecoder decoder; + /** + * Force the creation of MessageEncoder and MessageDecoder so that we can make sure they are created + * before the current context class loader is switched to ExecutorClassLoader. + * + * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class, and the + * implementation calls "Class.forName" to check if this class is already generated.
If the + * following two objects are created in "ExecutorClassLoader.findClass", it will cause + * "ClassCircularityError". This is because loading this Netty generated class will call + * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use + * RPC to load it and cause to load the non-exist matcher class again. JVM will report + * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714) + */ + private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; + private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this(conf, rpcHandler, false); @@ -75,8 +87,6 @@ public TransportContext( boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; - this.encoder = new MessageEncoder(); - this.decoder = new MessageDecoder(); this.closeIdleConnections = closeIdleConnections; } @@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline( try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() - .addLast("encoder", encoder) + .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) - .addLast("decoder", decoder) + .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index e895f13f4545..b50e043d5c9c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -100,8 +100,10 @@ public TransportClientFactory( IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); - // TODO: Make thread pool name configurable. - this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + this.workerGroup = NettyUtils.createEventLoop( + ioMode, + conf.clientThreads(), + conf.getModuleName() + "-client"); this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); } @@ -120,7 +122,8 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) throws IOException { + public TransportClient createClient(String remoteHost, int remotePort) + throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. // Use unresolved address here to avoid DNS resolution each time we creates a client. @@ -188,13 +191,14 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO * As with {@link #createClient(String, int)}, this method is blocking. 
*/ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) - throws IOException { + throws IOException, InterruptedException { final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); return createClient(address); } /** Create a completely new {@link TransportClient} to the remote address. */ - private TransportClient createClient(InetSocketAddress address) throws IOException { + private TransportClient createClient(InetSocketAddress address) + throws IOException, InterruptedException { logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); @@ -221,7 +225,7 @@ public void initChannel(SocketChannel ch) { // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); - if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + if (!cf.await(conf.connectionTimeoutMs())) { throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index f0956438ade2..39a7495828a8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + public static final MessageDecoder INSTANCE = new MessageDecoder(); + + private MessageDecoder() {} + @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 276f16637efc..997f74e1a21b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder { private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + public static final MessageEncoder INSTANCE = new MessageEncoder(); + + private MessageEncoder() {} + /*** * Encodes a Message by invoking its encode() method. For non-data messages, we will add one * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c33848c8406c..56782a832787 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,7 +26,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -48,7 +47,7 @@ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. */ -public class TransportChannelHandler extends SimpleChannelInboundHandler { +public class TransportChannelHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -88,14 +87,14 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { try { requestHandler.channelActive(); } catch (RuntimeException e) { - logger.error("Exception from request handler while registering channel", e); + logger.error("Exception from request handler while channel is active", e); } try { responseHandler.channelActive(); } catch (RuntimeException e) { - logger.error("Exception from response handler while registering channel", e); + logger.error("Exception from response handler while channel is active", e); } - super.channelRegistered(ctx); + super.channelActive(ctx); } @Override @@ -103,22 +102,24 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { requestHandler.channelInactive(); } catch (RuntimeException e) { - logger.error("Exception from request handler while unregistering channel", e); + logger.error("Exception from request handler while channel is inactive", e); } try { responseHandler.channelInactive(); } catch (RuntimeException e) { - logger.error("Exception from response handler while unregistering channel", e); + logger.error("Exception from response handler while channel is inactive", e); } - super.channelUnregistered(ctx); + super.channelInactive(ctx); } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); - } else { + } else if (request instanceof ResponseMessage) { responseHandler.handle((ResponseMessage) request); + } else { + ctx.fireChannelRead(request); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 
0d7a677820d3..047c5f3f1f09 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -89,7 +89,7 @@ private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = - NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server"); EventLoopGroup workerGroup = bossGroup; PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator( diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103ccc..fc5cc091f6e6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -73,6 +73,10 @@ private String getConfKey(String suffix) { return "spark." + module + "." + suffix; } + public String getModuleName() { + return module; + } + /** IO mode: nio or epoll */ public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 6c8dd742f4b6..bb1c40c4b0e0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -49,11 +49,11 @@ public class ProtocolSuite { private void testServerToClient(Message msg) { EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeInbound(serverChannel.readOutbound()); @@ -65,11 +65,11 @@ private void testServerToClient(Message msg) { private void testClientToServer(Message msg) { EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeInbound(clientChannel.readOutbound()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 959396bb8c26..9e8057438db9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -226,6 +226,8 @@ public StreamManager getStreamManager() { callback0.latch.await(60, TimeUnit.SECONDS); assertTrue(callback0.failure instanceof IOException); + // make sure callback1 is called. 
+ callback1.latch.await(60, TimeUnit.SECONDS); // failed at same time as previous assertTrue(callback1.failure instanceof IOException); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 44d16d54225e..43e63efed4a5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -100,6 +100,8 @@ public void run() { clients.add(client); } catch (IOException e) { failed.incrementAndGet(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } } }; @@ -143,7 +145,7 @@ public void reuseClientsUpToConfigVariableConcurrent() throws Exception { } @Test - public void returnDifferentClientsForDifferentServers() throws IOException { + public void returnDifferentClientsForDifferentServers() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); @@ -172,7 +174,7 @@ public void neverReturnInactiveClients() throws IOException, InterruptedExceptio } @Test - public void closeBlockClientsWithFactory() throws IOException { + public void closeBlockClientsWithFactory() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 511e1f29de36..1a2d85a2ead6 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -70,6 +70,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + log4j log4j diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 772fb88325b3..eea5cf7bec0f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -101,7 +101,7 @@ public void fetchBlocks( new RetryingBlockFetcher.BlockFetchStarter() { @Override public void createAndStart(String[] blockIds, BlockFetchingListener listener) - throws IOException { + throws IOException, InterruptedException { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); } @@ -136,7 +136,7 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) throws IOException { + ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { checkInit(); TransportClient client = clientFactory.createUnmanagedClient(host, port); try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java 
b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 72bd0f803da3..5be855048e4d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -57,7 +57,8 @@ public interface BlockFetchStarter { * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection * issues. */ - void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException, InterruptedException; } /** Shared executor service used for waiting and retrying. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 42cedd994315..c6d6029d5a2e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -69,7 +69,7 @@ public void registerDriverWithShuffleService( String host, int port, long heartbeatTimeoutMs, - long heartbeatIntervalMs) throws IOException { + long heartbeatIntervalMs) throws IOException, InterruptedException { checkInit(); ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 6ba937dddb2a..81a3f065d6fc 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -103,7 +103,7 @@ public void afterEach() { } @Test - public void testGoodClient() throws IOException { + public void testGoodClient() throws IOException, InterruptedException { clientFactory = context.createClientFactory( Lists.newArrayList( new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); @@ -133,7 +133,7 @@ public void testBadClient() { } @Test - public void testNoSaslClient() throws IOException { + public void testNoSaslClient() throws IOException, InterruptedException { clientFactory = context.createClientFactory( Lists.newArrayList()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 552b5366c593..a04b6825d19d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -240,7 +240,7 @@ public void testFetchNoServer() throws Exception { } private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) - throws IOException { + throws IOException, InterruptedException { ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), diff --git 
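
The recurring signature change in these files threads the checked InterruptedException from TransportClientFactory.createClient out through BlockFetchStarter and the shuffle clients instead of burying it in an IOException. Where a caller sits behind Runnable.run and cannot rethrow it (as in the test change above), a common shape, sketched here with assumed names, is:

    import java.io.IOException;

    public class FetchTask implements Runnable {

      interface Starter {
        // mirrors BlockFetchStarter.createAndStart after the change above
        void createAndStart(String[] blockIds) throws IOException, InterruptedException;
      }

      private final Starter starter;
      private final String[] blockIds;

      FetchTask(Starter starter, String[] blockIds) {
        this.starter = starter;
        this.blockIds = blockIds;
      }

      @Override
      public void run() {
        try {
          starter.createAndStart(blockIds);
        } catch (IOException e) {
          // retryable I/O failure: count it and let the retry logic decide
        } catch (InterruptedException e) {
          // run() cannot rethrow a checked exception; restore the interrupt
          // flag and surface the failure unchecked. (The test above simply
          // rethrows; restoring the flag is the stricter variant.)
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        }
      }
    }
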
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index a0f69ca29a28..6ef1a2d2c89f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -59,7 +59,7 @@ public void afterEach() { } @Test - public void testValid() throws IOException { + public void testValid() throws IOException, InterruptedException { validate("my-app-id", "secret", false); } @@ -82,12 +82,13 @@ public void testBadSecret() { } @Test - public void testEncryption() throws IOException { + public void testEncryption() throws IOException, InterruptedException { validate("my-app-id", "secret", true); } /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey, boolean encrypt) throws IOException { + private void validate(String appId, String secretKey, boolean encrypt) + throws IOException, InterruptedException { ExternalShuffleClient client = new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); client.init(appId); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 91882e3b3bcd..f221544e9e5e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -66,7 +66,7 @@ public void afterEach() { } @Test - public void testNoFailures() throws IOException { + public void testNoFailures() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -85,7 +85,7 @@ public void testNoFailures() throws IOException { } @Test - public void testUnrecoverableFailure() throws IOException { + public void testUnrecoverableFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -104,7 +104,7 @@ public void testUnrecoverableFailure() throws IOException { } @Test - public void testSingleIOExceptionOnFirst() throws IOException { + public void testSingleIOExceptionOnFirst() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -127,7 +127,7 @@ public void testSingleIOExceptionOnFirst() throws IOException { } @Test - public void testSingleIOExceptionOnSecond() throws IOException { + public void testSingleIOExceptionOnSecond() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -149,7 +149,7 @@ public void testSingleIOExceptionOnSecond() throws IOException { } @Test - public void testTwoIOExceptions() throws IOException { + public void testTwoIOExceptions() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -177,7 +177,7 @@ public void testTwoIOExceptions() throws IOException { } @Test - public void 
testThreeIOExceptions() throws IOException { + public void testThreeIOExceptions() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -209,7 +209,7 @@ public void testThreeIOExceptions() throws IOException { } @Test - public void testRetryAndUnrecoverable() throws IOException { + public void testRetryAndUnrecoverable() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -252,7 +252,7 @@ public void testRetryAndUnrecoverable() throws IOException { @SuppressWarnings("unchecked") private static void performInteractions(List> interactions, BlockFetchingListener listener) - throws IOException { + throws IOException, InterruptedException { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 606ad1573961..fb6c24118589 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -50,6 +50,17 @@ spark-tags_${scala.binary.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.hadoop @@ -76,6 +87,9 @@ *:* + + org.scala-lang:scala-library + @@ -87,7 +101,7 @@ - + com.fasterxml.jackson ${spark.shade.packageName}.com.fasterxml.jackson diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 626f023a5b99..ff2d5c52730b 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 1c60d510e570..b9bf0342eb60 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -36,9 +36,9 @@ - org.scalatest - scalatest_${scala.binary.version} - compile + org.scala-lang + scala-library + ${scala.version} diff --git a/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/test/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/test/java/org/apache/spark/tags/DockerTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 45af98d94ef9..f8a0e577777e 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark 
spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + com.twitter chill_${scala.binary.version} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 671b8c747594..ba35cf250e48 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -46,18 +46,22 @@ public final class Platform { private static final boolean unaligned; static { boolean _unaligned; - // use reflection to access unaligned field - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } catch (Throwable t) { - // We at least know x86 and x64 support unaligned access. - String arch = System.getProperty("os.arch", ""); - //noinspection DynamicRegexReplaceableByCompiledPattern - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64")) { + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } } unaligned = _unaligned; } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a75..a7b0e6f80c2b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public final long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e09a6b7d93a9..b03e7182b16a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -816,6 +816,190 @@ public UTF8String translate(Map dict) { return fromString(sb.toString()); } + private int getDigit(byte b) { + if (b >= '0' && b <= '9') { + return b - '0'; + } + throw new NumberFormatException(toString()); + } + + /** + * Parses this UTF8String to long. 
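
The Platform change above short-circuits the architecture check before the reflective probe. Restated on its own (the reflection targets java.nio.Bits, a JDK-internal class, hence the broad Throwable catch and the conservative fallback list):

    import java.lang.reflect.Method;

    public final class UnalignedProbe {

      public static boolean unalignedAccessSupported() {
        String arch = System.getProperty("os.arch", "");
        if (arch.equals("ppc64le") || arch.equals("ppc64")) {
          // java.nio.Bits.unaligned() reports false here (JDK-8165231),
          // but both platforms handle unaligned loads and stores.
          return true;
        }
        try {
          Class<?> bits =
            Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
          Method unaligned = bits.getDeclaredMethod("unaligned");
          unaligned.setAccessible(true);
          return Boolean.TRUE.equals(unaligned.invoke(null));
        } catch (Throwable t) {
          // The reflective probe can fail on restricted JDKs; assume support
          // only on architectures known to allow unaligned access.
          return arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$");
        }
      }

      private UnalignedProbe() {}
    }
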
+ * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyLong.parseLong in Hive. + */ + public long toLong() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final long stopValue = Long.MIN_VALUE / radix; + long result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we + // can just use `result > 0` to check overflow. If result overflows, we should stop and throw + // exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + /** + * Parses this UTF8String to int. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyInt.parseInt in Hive. + * + * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance + * reasons, like Hive does. + */ + public int toInt() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final int stopValue = Integer.MIN_VALUE / radix; + int result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) 
+ break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop and + // throw exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + public short toShort() { + int intValue = toInt(); + short result = (short) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + + public byte toByte() { + int intValue = toInt(); + byte result = (byte) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 55cb094b4af4..2ecb4f1464a4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -15,6 +15,6 @@ # limitations under the License. 
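
The accumulate-negative trick in the toLong/toInt additions above is worth seeing in isolation. A minimal digits-only variant (illustrative, not Spark code): accumulating toward Integer.MIN_VALUE works because the negative range is one wider than the positive range, so even "-2147483648" never overflows mid-parse.

    public final class NegativeAccumulationParse {

      public static int parseInt(String s) {
        boolean negative = s.startsWith("-");
        int i = negative || s.startsWith("+") ? 1 : 0;
        if (i == s.length()) throw new NumberFormatException(s);

        final int radix = 10;
        final int stopValue = Integer.MIN_VALUE / radix;
        int result = 0;
        for (; i < s.length(); i++) {
          char c = s.charAt(i);
          if (c < '0' || c > '9') throw new NumberFormatException(s);
          if (result < stopValue) throw new NumberFormatException(s);  // next *10 would overflow
          result = result * radix - (c - '0');
          if (result > 0) throw new NumberFormatException(s);          // wrapped past MIN_VALUE
        }
        if (!negative) {
          result = -result;
          if (result < 0) throw new NumberFormatException(s);  // was exactly MIN_VALUE
        }
        return result;
      }
    }
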
# -spark.mesos.executor.docker.image: +spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/core/pom.xml b/core/pom.xml index eac99ab82a2e..bad3655452fb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.1.2-SNAPSHOT ../pom.xml @@ -337,6 +337,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.commons commons-crypto diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263..ea5f1a9abf69 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 1a700aa37554..3385d0e6a31c 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import javax.annotation.concurrent.GuardedBy; import java.io.IOException; +import java.nio.channels.ClosedByInterruptException; import java.util.Arrays; import java.util.BitSet; import java.util.HashSet; @@ -156,6 +157,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { break; } } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). + logger.error("error while calling spill() on " + c, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + c, e); throw new OutOfMemoryError("error while calling spill() on " + c + " : " @@ -174,6 +179,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { Utils.bytesToString(released), consumer); got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } + } catch (ClosedByInterruptException e) { + // This called by user to kill a task (e.g: speculative task). 
+ logger.error("error while calling spill() on " + consumer, e); + throw new RuntimeException(e.getMessage()); } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " @@ -378,14 +387,14 @@ public long cleanUpAllAllocatedMemory() { for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { // In case of failed task, it's normal to see leaked memory - logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); } } consumers.clear(); for (MemoryBlock page : pageTable) { if (page != null) { - logger.warn("leak a page: " + page + " in task " + taskAttemptId); + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f235c434be7b..2fde5c300f07 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -40,6 +40,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -320,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, or in - * cases where users have explicitly disabled use of {@code transferTo} in order to work around - * kernel bugs. + * cases where the IO compression codec does not support concatenation of compressed data, when + * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in + * order to work around kernel bugs. * * @param spills the spills to merge. 
* @param outputFile the file to write the merged data to. @@ -337,7 +340,11 @@ private long[] mergeSpillsWithFileStream( final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; - OutputStream mergedFileOutputStream = null; + + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( + new FileOutputStream(outputFile)); boolean threwException = true; try { @@ -345,34 +352,35 @@ private long[] mergeSpillsWithFileStream( spillInputStreams[i] = new FileInputStream(spills[i].file); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() calls, so that we can close the higher + // level streams to make sure all data is really flushed and internal state is cleaned. + OutputStream partitionOutput = new CloseShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { - mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean innerThrewException = true; + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); try { - partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); if (compressionCodec != null) { partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); - innerThrewException = false; + ByteStreams.copy(partitionInputStream, partitionOutput); } finally { - Closeables.close(partitionInputStream, innerThrewException); + partitionInputStream.close(); } } } - mergedFileOutputStream.flush(); - mergedFileOutputStream.close(); - partitionLengths[partition] = (outputFile.length() - initialFileLength); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { @@ -414,17 +422,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - long bytesToTransfer = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - while (bytesToTransfer > 0) { - final long actualBytesTransferred = spillInputChannel.transferTo( - 
spillInputChannelPositions[i], - bytesToTransfer, - mergedFileOutputChannel); - spillInputChannelPositions[i] += actualBytesTransferred; - bytesToTransfer -= actualBytesTransferred; - } + Utils.copyFileStreamNIO( + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d2fcdea4f2ce..4bef21b6b4e4 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -170,6 +170,8 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; + private final int initialCapacity; + private final BlockManager blockManager; private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; @@ -202,6 +204,7 @@ public BytesToBytesMap( throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } + this.initialCapacity = initialCapacity; allocate(initialCapacity); } @@ -695,7 +698,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff if (numKeys == MAX_CAPACITY // The map could be reused from last spill (because of no enough memory to grow), // then we don't try to grow again if hit the `growthThreshold`. - || !canGrowArray && numKeys > growthThreshold) { + || !canGrowArray && numKeys >= growthThreshold) { return false; } @@ -739,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff longArray.set(pos * 2 + 1, keyHashcode); isDefined = true; - if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) { + if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) { try { growAndRehash(); } catch (OutOfMemoryError oom) { @@ -902,12 +905,13 @@ public LongArray getArray() { public void reset() { numKeys = 0; numValues = 0; - longArray.zeroOut(); - + freeArray(longArray); while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); } + allocate(initialCapacity); + canGrowArray = true; currentPage = null; pageCursor = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55..3dd318471008 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
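
In the mergeSpillsWithTransferTo hunk above, the hand-written loop is replaced by Utils.copyFileStreamNIO. The name comes from the diff; the body below is only a sketch of the invariant such a helper must keep, since FileChannel.transferTo may legally transfer fewer bytes than requested:

    import java.io.IOException;
    import java.nio.channels.FileChannel;

    public final class ChannelCopy {

      // Loop until the requested span has been fully copied; a single
      // transferTo call gives no such guarantee. Production code may also
      // want to check for forward progress to avoid spinning on 0 returns.
      public static void copyFileStreamNIO(
          FileChannel input, FileChannel output, long startPosition, long bytesToCopy)
          throws IOException {
        long position = startPosition;
        long remaining = bytesToCopy;
        while (remaining > 0) {
          long transferred = input.transferTo(position, remaining, output);
          position += transferred;
          remaining -= transferred;
        }
      }
    }
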
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
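
The int-to-long widening running through RadixSort above keeps all cursor arithmetic in long and narrows only at the public boundary, where Guava's Ints.checkedCast makes the narrowing fail loudly instead of wrapping. A small demonstration of the failure mode (record counts are made up):

    import com.google.common.primitives.Ints;

    public final class SafeNarrowing {
      public static void main(String[] args) {
        long numRecords = 1L << 30;       // a plausible in-memory record count
        long outIndex = numRecords * 2L;  // 2^31: exceeds Integer.MAX_VALUE

        // A plain (int) cast would wrap silently to Integer.MIN_VALUE here;
        // checkedCast turns the latent narrowing bug into an immediate error.
        try {
          int narrowed = Ints.checkedCast(outIndex);
          System.out.println("fits: " + narrowed);
        } catch (IllegalArgumentException e) {
          System.out.println("out of int range: " + outIndex);
        }
      }
    }
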
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafa..5b42843717e9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -22,6 +22,8 @@ import org.apache.avro.reflect.Nullable; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -253,6 +255,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea private long keyPrefix; private int recordLength; private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -283,6 +286,14 @@ public boolean hasNext() { @Override public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. 
+ if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); @@ -322,7 +333,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index a658e5eb47b7..b6323c624b7b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,6 +23,8 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + private final TaskContext taskContext = TaskContext.get(); public UnsafeSorterSpillReader( SerializerManager serializerManager, @@ -94,6 +97,14 @@ public boolean hasNext() { @Override public void loadNext() throws IOException { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. 
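
Both loadNext() changes in this patch inline the same cooperative cancellation check rather than wrapping the iterator. Reduced to a generic pattern (a plain RuntimeException and an AtomicBoolean stand in for Spark's TaskKilledException and TaskContext.isInterrupted()):

    import java.util.concurrent.atomic.AtomicBoolean;

    final class SpillScanner {
      private final AtomicBoolean killed;  // stand-in for the task's kill flag
      private final int numRecords;
      private int position;

      SpillScanner(AtomicBoolean killed, int numRecords) {
        this.killed = killed;
        this.numRecords = numRecords;
      }

      boolean hasNext() { return position < numRecords; }

      // Poll the kill flag once per record at the top of the hot loading
      // method: cancellation costs one branch per record, with none of the
      // wrapper indirection an InterruptibleIterator would add.
      void loadNext() {
        if (killed.get()) {
          throw new RuntimeException("task killed");
        }
        position++;  // real code would decode the next record here
      }
    }
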
+ if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } recordLength = din.readInt(); keyPrefix = din.readLong(); if (recordLength > arr.length) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 1df67337ea03..fe5db6aa26b6 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -411,10 +411,6 @@ $(document).ready(function () { } ], "columnDefs": [ - { - "targets": [ 15 ], - "visible": logsExist(response) - }, { "targets": [ 16 ], "visible": getThreadDumpEnabled() @@ -423,7 +419,8 @@ $(document).ready(function () { "order": [[0, "asc"]] }; - $(selector).DataTable(conf); + var dt = $(selector).DataTable(conf); + dt.column(15).visible(logsExist(response)); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 000000000000..55d540d8317a --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 1fd6ef4a7125..c2afa993b2f2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -39,7 +39,7 @@ Started - + Completed @@ -68,16 +68,16 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} - {{attemptId}} + {{attemptId}} {{startTime}} - {{endTime}} + {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2a32e18672a2..a4430347a0b2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -114,12 +114,19 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } } - var data = {"applications": array} + var data = { + "uiroot": uiRoot, + "applications": array + } + $.get("static/historypage-template.html", function(template) { historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); var selector = "#history-summary-table"; @@ -135,6 +142,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; @@ -161,6 +171,13 @@ $(document).ready(function() { } } + if (requestedIncomplete) { + var completedCells = document.getElementsByClassName("completedColumn"); + for (i = 0; i < completedCells.length; i++) { + completedCells[i].style.display='none'; + } + } + var durationCells = document.getElementsByClassName("durationClass"); for (i = 0; i < durationCells.length; i++) { var timeInMilliseconds = parseInt(durationCells[i].title); diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c0..b5c43e5788bc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy// part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint. 
+ var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index ff241470f32d..9960d5c34d1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -207,8 +207,8 @@ sorttable = { hasInputs = (typeof node.getElementsByTagName == 'function') && node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { + + if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) { return node.getAttribute("sorttable_customkey"); } else if (typeof node.textContent != 'undefined' && !hasInputs) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860e..0315ebf5c48a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b157f3e0a407..319a719efaa7 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -246,4 +246,8 @@ a.expandbutton { text-align: center; margin: 0; padding: 4px 0; +} + +.table-cell-width-limited td { + max-width: 600px; } \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e37307aa1f70..0fa1fcf25f8b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var uiRoot = ""; + +function setUIRoot(val) { + uiRoot = val; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 9d1f1d59dbce..7bea636c94aa 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -26,7 +26,7 @@ package org.apache.spark * * An accumulator is created from an initial value `v` by calling * [[SparkContext#accumulator SparkContext.accumulator]]. - * Tasks running on the cluster can then add to it using the [[Accumulable#+= +=]] operator. + * Tasks running on the cluster can then add to it using the `+=` operator. * However, they cannot read its value. Only the driver program can read the accumulator's value, * using its [[#value]] method. * diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e7..4d884dec0791 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,7 +18,8 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} +import java.util.Collections +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} import scala.collection.JavaConverters._ @@ -58,7 +59,12 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() + /** + * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they + * have not been handled by the reference queue. 
+ */ + private val referenceBuffer = + Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap) private val referenceQueue = new ReferenceQueue[AnyRef] @@ -139,7 +145,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } @@ -176,10 +182,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { .map(_.asInstanceOf[CleanupTaskWeakReference]) // Synchronize here to avoid being interrupted on stop() synchronized { - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer.remove(reference.get) - task match { + reference.foreach { ref => + logDebug("Got cleaning task " + ref.task) + referenceBuffer.remove(ref) + ref.task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) case CleanShuffle(shuffleId) => diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1366251d0618..f054a78d5788 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f8f0f513134..6f5c31d7ab71 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -322,7 +322,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, if (minSizeForBroadcast > maxRpcMessageSize) { val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + - "message that is to large." + "message that is too large." logError(msg) throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed6..f83f5278e8b8 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. 
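Because `RangePartitioner` picks its bounds by sampling, the note above is easy to reproduce: with only a handful of distinct keys, requesting ten partitions yields fewer. A hedged sketch (local master, object name, and key values are illustrative):

```scala
import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

object RangePartitionerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("range-sketch"))
    // Only three distinct keys: asking for 10 partitions cannot produce 10 useful ranges.
    val rdd = sc.parallelize(Seq(1 -> "a", 2 -> "b", 3 -> "c"))
    val partitioner = new RangePartitioner(10, rdd)
    // numPartitions may be less than the 10 requested, per the note above.
    println(partitioner.numPartitions)
    sc.stop()
  }
}
```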
*/ diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index be19179b00a4..5f14102c3c36 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -150,8 +150,8 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * <a href="https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https"> + * Oracle blog page</a>. * * You can optionally specify the default configuration. If you do, for each setting which is * missing in SparkConf, the corresponding setting is used from the default configuration. diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 199365ad925a..6e12cd16a3a5 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,7 +21,6 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate -import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes @@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder -import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -185,7 +183,9 @@ import org.apache.spark.util.Utils * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) +private[spark] class SecurityManager( + sparkConf: SparkConf, + val ioEncryptionKey: Option[Array[Byte]] = None) extends Logging with SecretKeyHolder { import SecurityManager._ @@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing acls enabled to: " + aclsOn) } + def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey + /** * Generates or looks up the secret key. * @@ -559,19 +561,4 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" - /** - * Setup the cryptographic key used by IO encryption in credentials. The key is generated using - * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. - */ - def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) - val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) - keyGen.init(keyLen) - - val ioKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) - } - } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c9..d78b9f1b2968 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining.
For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { @@ -262,7 +262,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getTimeAsSeconds(key: String): Long = { Utils.timeStringAsSeconds(get(key)) @@ -279,7 +279,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) @@ -296,7 +296,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsBytes(key: String): Long = { Utils.byteStringAsBytes(get(key)) @@ -320,7 +320,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsKb(key: String): Long = { Utils.byteStringAsKb(get(key)) @@ -337,7 +337,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsMb(key: String): Long = { Utils.byteStringAsMb(get(key)) @@ -354,7 +354,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. 
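The suffix-aware getters above accept values such as `120s` or `2g` and normalize them to the unit named in the method. A small sketch (the config keys shown are just examples):

```scala
import org.apache.spark.SparkConf

object ConfSuffixSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.network.timeout", "120s")    // time value with suffix
      .set("spark.driver.maxResultSize", "2g") // size value with suffix
    println(conf.getTimeAsMs("spark.network.timeout"))      // 120000
    println(conf.getSizeAsMb("spark.driver.maxResultSize")) // 2048
    // Reading a key that was never set throws java.util.NoSuchElementException.
  }
}
```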
- * @throws NoSuchElementException + * @throws java.util.NoSuchElementException */ def getSizeAsGb(key: String): Long = { Utils.byteStringAsGb(get(key)) @@ -378,7 +378,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray } - /** Get all parameters that start with `prefix` */ + /** + * Get all parameters that start with `prefix` + */ def getAllWithPrefix(prefix: String): Array[(String, String)] = { getAll.filter { case (k, v) => k.startsWith(prefix) } .map { case (k, v) => (k.substring(prefix.length), v) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd..cf03ae545635 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.{MalformedURLException, URI} +import java.net.{URI} import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") + warnDeprecatedVersions() + /* ------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -279,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -346,6 +348,16 @@ class SparkContext(config: SparkConf) extends Logging { value } + private def warnDeprecatedVersions(): Unit = { + val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) + if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) { + logWarning("Support for Java 7 is deprecated as of Spark 2.0.0") + } + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { + logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") + } + } + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN @@ -410,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging { } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + - s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") - } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. 
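The `warnDeprecatedVersions` helper above splits `java.version` on `+`, `.`, or `-` with a limit of 3 and inspects the second component. A standalone sketch of that parsing (the version strings are illustrative):

```scala
object JavaVersionSketch {
  def main(args: Array[String]): Unit = {
    // Same split as warnDeprecatedVersions: limit 3 keeps any build suffix intact.
    def isJava7(version: String): Boolean = {
      val parts = version.split("[+.\\-]+", 3)
      parts.length >= 2 && parts(1).toInt == 7
    }
    println(isJava7("1.7.0_80"))  // true  -> Array("1", "7", "0_80")
    println(isJava7("1.8.0_112")) // false -> Array("1", "8", "0_112")
  }
}
```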
@@ -633,7 +641,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).orNull @@ -651,7 +659,7 @@ class SparkContext(config: SparkConf) extends Logging { * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all + * The application can also use `org.apache.spark.SparkContext.cancelJobGroup` to cancel all * running jobs in this group. For example, * {{{ * // In the main thread: @@ -688,7 +696,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -915,7 +923,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -936,12 +944,11 @@ class SparkContext(config: SparkConf) extends Logging { classOf[LongWritable], classOf[BytesWritable], conf = conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() assert(bytes.length == recordLength, "Byte array does not have correct length") bytes } - data } /** @@ -958,7 +965,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -983,7 +990,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1022,7 +1029,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1046,7 +1053,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1072,7 +1079,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1112,7 +1119,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1138,7 +1145,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1157,7 +1164,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1187,7 +1194,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1268,7 +1275,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1297,7 +1304,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1318,19 +1325,21 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** @@ -1370,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. */ def collectionAccumulator[T]: CollectionAccumulator[T] = { @@ -1380,7 +1389,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. 
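A short usage sketch of the `collectionAccumulator` described above (local master and the accumulator name are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CollectionAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("acc-sketch"))
    // Starts as an empty list; tasks add elements, only the driver reads the value.
    val badRecords = sc.collectionAccumulator[String]("badRecords")
    sc.parallelize(Seq("ok", "bad:1", "ok", "bad:2"))
      .foreach(r => if (r.startsWith("bad")) badRecords.add(r))
    println(badRecords.value) // e.g. [bad:1, bad:2]
    sc.stop()
  }
}
```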
*/ def collectionAccumulator[T](name: String): CollectionAccumulator[T] = { @@ -1538,7 +1547,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1560,7 +1569,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1578,7 +1587,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1627,7 +1636,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -1716,29 +1726,20 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. 
- logError("Error adding jar (" + e + "), was the --addJars option used?") - null + try { + val file = new File(uri.getPath) + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") } + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) + null } // A JAR file which exists locally on every worker node case "local" => @@ -1762,8 +1763,31 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq - // Shut down the SparkContext. - def stop() { + /** + * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark + * may wait for some internal threads to finish. It's better to use this method to stop + * SparkContext instead. + */ + private[spark] def stopInNewThread(): Unit = { + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + try { + SparkContext.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + } + }.start() + } + + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") @@ -1826,6 +1850,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() @@ -2027,7 +2054,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] + * Cancel active jobs for the specified group. See `org.apache.spark.SparkContext.setJobGroup` * for more information. */ def cancelJobGroup(groupId: String) { @@ -2045,7 +2072,7 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel a given job if it's scheduled or running. * * @param jobId the job ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelJob(jobId: Int) { dagScheduler.cancelJob(jobId) @@ -2055,7 +2082,7 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel a given stage and all jobs associated with it. * * @param stageId the stage ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @note Throws `InterruptedException` if the cancel message cannot be sent */ def cancelStage(stageId: Int) { dagScheduler.cancelStage(stageId) @@ -2285,7 +2312,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. 
* - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -2310,7 +2337,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(): SparkContext = { @@ -2322,6 +2349,13 @@ object SparkContext extends Logging { } } + /** Return the current active [[SparkContext]] if any. */ + private[spark] def getActive: Option[SparkContext] = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + Option(activeContext.get()) + } + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -2550,8 +2584,8 @@ object SparkContext extends Logging { val serviceLoaders = ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url)) if (serviceLoaders.size > 1) { - throw new SparkException(s"Multiple Cluster Managers ($serviceLoaders) registered " + - s"for the url $url:") + throw new SparkException( + s"Multiple external cluster managers registered for the url $url: $serviceLoaders") } serviceLoaders.headOption } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1ffeb129880f..1296386ac9bd 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ @@ -165,15 +166,20 @@ object SparkEnv extends Logging { val bindAddress = conf.get(DRIVER_BIND_ADDRESS) val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) val port = conf.get("spark.driver.port").toInt + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } create( conf, SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, port, - isDriver = true, - isLocal = isLocal, - numUsableCores = numCores, + isLocal, + numCores, + ioEncryptionKey, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -189,6 +195,7 @@ object SparkEnv extends Logging { hostname: String, port: Int, numCores: Int, + ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { val env = create( conf, @@ -196,9 +203,9 @@ object SparkEnv extends Logging { hostname, hostname, port, - isDriver = false, - isLocal = isLocal, - numUsableCores = numCores + isLocal, + numCores, + ioEncryptionKey ) SparkEnv.set(env) env @@ -213,18 +220,26 @@ object SparkEnv extends Logging { bindAddress: String, advertiseAddress: String, port: Int, - isDriver: Boolean, isLocal: Boolean, 
numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], listenerBus: LiveListenerBus = null, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") } - val securityManager = new SecurityManager(conf) + val securityManager = new SecurityManager(conf, ioEncryptionKey) + ioEncryptionKey.foreach { _ => + if (!securityManager.isSaslEncryptionEnabled()) { + logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") + } + } val systemName = if (isDriver) driverSystemName else executorSystemName val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, @@ -270,7 +285,7 @@ object SparkEnv extends Logging { "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val serializerManager = new SerializerManager(serializer, conf) + val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) val closureSerializer = new JavaSerializer(conf) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc86..7f75a393bf8f 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -162,7 +162,7 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private[spark] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 52c4656c271b..22a553e68439 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -112,7 +112,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getExecutorInfos: Array[SparkExecutorInfo] = { val executorIdToRunningTasks: Map[String, Int] = - sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors() + sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors sc.getExecutorStorageStatus.map { status => val bmId = status.blockManagerId diff --git 
a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 27abccf5ac2a..0fd777ed1282 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -164,7 +164,7 @@ abstract class TaskContext extends Serializable { /** * Get a local property set upstream in the driver, or null if it is missing. See also - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String @@ -174,7 +174,7 @@ abstract class TaskContext extends Serializable { /** * ::DeveloperApi:: * Returns all metrics sources with the given name which are associated with the instance - * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + * which runs the task. For more information see `org.apache.spark.metrics.MetricsSystem`. */ @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7ca3c103dbf5..7745387dbceb 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -65,7 +65,7 @@ sealed trait TaskFailedReason extends TaskEndReason { /** * :: DeveloperApi :: - * A [[org.apache.spark.scheduler.ShuffleMapTask]] that completed successfully earlier, but we + * A `org.apache.spark.scheduler.ShuffleMapTask` that completed successfully earlier, but we * lost the executor before the stage completed. This means Spark needs to reschedule the task * to be re-executed on a different executor. */ diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 871b9d1ad575..5cdc4eeeccbc 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -18,19 +18,23 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.net.{URI, URL} +import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.nio.file.Paths +import java.security.SecureRandom +import java.security.cert.X509Certificate import java.util.Arrays import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} +import javax.net.ssl._ +import javax.servlet.http.HttpServletResponse +import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.io.{ByteStreams, Files} -import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -182,11 +186,60 @@ private[spark] object TestUtils { assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } + /** + * Returns the response code and url (if redirected) from an HTTP(S) URL. 
+ */ + def httpResponseCodeAndURL( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): (Int, Option[String]) = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } + + // Disable cert and host name validation for HTTPS tests. + if (connection.isInstanceOf[HttpsURLConnection]) { + val sslCtx = SSLContext.getInstance("SSL") + val trustManager = new X509TrustManager { + override def getAcceptedIssuers(): Array[X509Certificate] = null + override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} + override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} + } + val verifier = new HostnameVerifier() { + override def verify(hostname: String, session: SSLSession): Boolean = true + } + sslCtx.init(null, Array(trustManager), new SecureRandom()) + connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory()) + connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier) + connection.setInstanceFollowRedirects(false) + } + + try { + connection.connect() + if (connection.getResponseCode == HttpServletResponse.SC_FOUND) { + (connection.getResponseCode, Option(connection.getHeaderField("Location"))) + } else { + (connection.getResponseCode(), None) + } + } finally { + connection.disconnect() + } + } + + /** + * Returns the response code from an HTTP(S) URL. + */ + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { + httpResponseCodeAndURL(url, method, headers)._1 + } } /** - * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. + * A `SparkListener` that detects whether spills have occurred in Spark jobs. */ private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad51..b71af0d42cdb 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -45,7 +45,9 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) import JavaDoubleRDD.fromRDD - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) /** @@ -153,7 +155,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +258,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. 
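The even-buckets fast path described above can be exercised through `histogram` on an RDD of doubles; a hedged sketch (local master and values are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object HistogramSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("hist-sketch"))
    val values = sc.parallelize(Seq(1.0, 5.0, 15.0, 25.0, 29.0))
    // Evenly spaced buckets [0,10), [10,20), [20,30]: evenBuckets = true enables
    // O(1) bucket lookup per element instead of a binary search.
    val counts = values.histogram(Array(0.0, 10.0, 20.0, 30.0), evenBuckets = true)
    println(counts.mkString(", ")) // 2, 1, 2
    sc.stop()
  }
}
```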
* buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaa..766aea213a97 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -54,7 +54,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) /** @@ -206,7 +208,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +225,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +236,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +260,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +270,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,8 +406,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. 
Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = @@ -409,8 +417,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = @@ -448,13 +456,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.subtractByKey(other)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], p: Partitioner): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, p)) @@ -539,8 +551,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(): JavaPairRDD[K, JIterable[V]] = diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7..41b5cab601c3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -34,7 +34,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) /** @@ -98,24 +100,32 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. 
+ * Return a sampled subset of this RDD with a random seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -153,7 +163,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) @@ -161,7 +171,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. */ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf21..eda16d957cc5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. 
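The sampling note above is observable directly: `fraction` sets the expected sample size, not an exact one, so counts differ across seeds. A sketch (local master and seeds are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object SampleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sample-sketch"))
    val rdd = sc.parallelize(1 to 1000)
    // fraction = 0.1 is an expected size, not a guarantee: counts vary run to run.
    println(rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L).count())
    println(rdd.sample(withReplacement = false, fraction = 0.1, seed = 43L).count())
    sc.stop()
  }
}
```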
*/ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd5..9481156bc93a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -238,7 +238,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}} * * then `rdd` contains * {{{ @@ -270,7 +272,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}}, * * then `rdd` contains * {{{ @@ -298,7 +302,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +320,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +370,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +400,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -416,7 +420,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. 
* If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +441,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +462,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +491,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +698,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -749,7 +753,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + * `org.apache.spark.api.java.JavaSparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) @@ -769,7 +773,7 @@ class JavaSparkContext(val sc: SparkContext) * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * The application can also use `org.apache.spark.api.java.JavaSparkContext.cancelJobGroup` * to cancel all running jobs in this group. For example, * {{{ * // In the main thread: @@ -802,7 +806,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Cancel active jobs for the specified group. See - * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + * `org.apache.spark.api.java.JavaSparkContext.setJobGroup` for more information. 
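The Writable-reuse caveat repeated through these hunks comes down to copying before caching, as in this minimal sketch (`sc` is assumed to be a Scala SparkContext; the path is illustrative):
{{{
import org.apache.hadoop.io.Text

// Copy the reused Writables into immutable Strings before caching; without
// the map, every cached record would alias the same mutated Text instances.
val safe = sc.sequenceFile("hdfs://a-hdfs-path", classOf[Text], classOf[Text])
  .map { case (k, v) => (k.toString, v.toString) }
  .cache()
}}}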
*/ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) @@ -811,7 +815,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced..6aa290ecd7bb 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala new file mode 100644 index 000000000000..3432700f1160 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentHashMap + +/** JVM object ID wrapper */ +private[r] case class JVMObjectId(id: String) { + require(id != null, "Object ID cannot be null.") +} + +/** + * Counter that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] class JVMObjectTracker { + + private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]() + private[this] val objCounter = new AtomicInteger() + + /** + * Returns the JVM object associated with the input key or None if not found. + */ + final def get(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.get(id)) + } else { + None + } + } + + /** + * Returns the JVM object associated with the input key or throws an exception if not found. 
+ */ + @throws[NoSuchElementException]("if key does not exist.") + final def apply(id: JVMObjectId): Object = { + get(id).getOrElse( + throw new NoSuchElementException(s"$id does not exist.") + ) + } + + /** + * Adds a JVM object to track and returns assigned ID, which is unique within this tracker. + */ + final def addAndGetId(obj: Object): JVMObjectId = { + val id = JVMObjectId(objCounter.getAndIncrement().toString) + objMap.put(id, obj) + id + } + + /** + * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. + */ + final def remove(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.remove(id)) + } else { + None + } + } + + /** + * Number of JVM objects being tracked. + */ + final def size: Int = objMap.size() + + /** + * Clears the tracker. + */ + final def clear(): Unit = objMap.clear() +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 550746c552d0..2d1152a03644 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel @@ -42,6 +42,9 @@ private[spark] class RBackend { private[this] var bootstrap: ServerBootstrap = null private[this] var bossGroup: EventLoopGroup = null + /** Tracks JVM objects returned to R for this RBackend instance. 
*/ + private[r] val jvmObjectTracker = new JVMObjectTracker + def init(): Int = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( @@ -94,6 +97,7 @@ private[spark] class RBackend { bootstrap.childGroup().shutdownGracefully() } bootstrap = null + jvmObjectTracker.clear() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 9f5afa29d6d2..cfd37ac54ba2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,7 +20,6 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.TimeUnit -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} @@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend) assert(numArgs == 1) writeInt(dos, 0) - writeObject(dos, args(0)) + writeObject(dos, args(0), server.jvmObjectTracker) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + server.jvmObjectTracker.remove(JVMObjectId(objToRemove)) writeInt(dos, 0) - writeObject(dos, null) + writeObject(dos, null, server.jvmObjectTracker) } catch { case e: Exception => logError(s"Removing $objId failed", e) @@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } + obj = server.jvmObjectTracker(JVMObjectId(objId)) + obj.getClass } val args = readArgs(numArgs, dis) @@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) // Write status bit writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker) } else if (methodName == "") { // methodName should be "" for constructor val ctors = cls.getConstructors @@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend) val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker) } else { throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) } @@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { _ => - readObject(dis) + readObject(dis, server.jvmObjectTracker) }.toArray } @@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend) } } -/** - * Helper singleton that tracks JVM objects returned to R. - * This is useful for referencing these objects in RPC calls. - */ -private[r] object JVMObjectTracker { - - // TODO: This map should be thread-safe if we want to support multiple - // connections at the same time - private[this] val objMap = new HashMap[String, Object] - - // TODO: We support only one connection now, so an integer is fine. - // Investigate using use atomic integer in the future. 
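A sketch of the call pattern for the new per-backend tracker introduced above, which replaces the singleton removed here (the class is `private[r]`, so this compiles only inside that package; the payload object is a placeholder):
{{{
val tracker = new JVMObjectTracker
val id = tracker.addAndGetId(new Object)    // register; returns a unique JVMObjectId
val obj: Object = tracker(id)               // apply(): throws NoSuchElementException if absent
val maybe: Option[Object] = tracker.get(id) // get(): None if absent
tracker.remove(id)                          // stop tracking; returns the removed object
}}}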
- private[this] var objCounter: Int = 0 - - def getObject(id: String): Object = { - objMap(id) - } - - def get(id: String): Option[Object] = { - objMap.get(id) - } - - def put(obj: Object): String = { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - def remove(id: String): Option[Object] = { - objMap.remove(id) - } - -} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index a1a5eb8cf55e..72ae0340aa3d 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.r +import java.io.File import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -127,6 +128,14 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.toString, value.toString) } + if (sparkEnvirMap.containsKey("spark.r.sql.derby.temp.dir") && + System.getProperty("derby.stream.error.file") == null) { + // This must be set before SparkContext is instantiated. + System.setProperty("derby.stream.error.file", + Seq(sparkEnvirMap.get("spark.r.sql.derby.temp.dir").toString, "derby.log") + .mkString(File.separator)) + } + val jsc = new JavaSparkContext(sparkConf) jars.foreach { jar => jsc.addJar(jar) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 7ef64723d959..88118392003e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -152,7 +152,7 @@ private[spark] class RRunner[U]( dataOut.writeInt(mode) if (isDataFrame) { - SerDe.writeObject(dataOut, colNames) + SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) } if (!iter.hasNext) { @@ -347,6 +347,8 @@ private[r] object RRunner { pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) + pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) + pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 77825e75e513..fdd8cf62f0e5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -84,7 +84,6 @@ private[spark] object RUtils { } } else { // Otherwise, assume the package is local - // TODO: support this for Mesos val sparkRPkgPath = localSparkRPackagePath.getOrElse { throw new SparkException("SPARK_HOME not set. 
Can't locate SparkR package.") } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 550e075a9512..dad928cdcfd0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { - type ReadObject = (DataInputStream, Char) => Object - type WriteObject = (DataOutputStream, Object) => Boolean + type SQLReadObject = (DataInputStream, Char) => Object + type SQLWriteObject = (DataOutputStream, Object) => Boolean - var sqlSerDe: (ReadObject, WriteObject) = _ + private[this] var sqlReadObject: SQLReadObject = _ + private[this] var sqlWriteObject: SQLWriteObject = _ - def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { - this.sqlSerDe = sqlSerDe + def setSQLReadObject(value: SQLReadObject): this.type = { + sqlReadObject = value + this + } + + def setSQLWriteObject(value: SQLWriteObject): this.type = { + sqlWriteObject = value + this } // Type mapping from R to Java @@ -56,32 +63,33 @@ private[spark] object SerDe { dis.readByte().toChar } - def readObject(dis: DataInputStream): Object = { + def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = { val dataType = readObjectType(dis) - readTypedObject(dis, dataType) + readTypedObject(dis, dataType, jvmObjectTracker) } def readTypedObject( dis: DataInputStream, - dataType: Char): Object = { + dataType: Char, + jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) case 'c' => readString(dis) - case 'e' => readMap(dis) + case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) - case 'a' => readArray(dis) - case 'l' => readList(dis) + case 'a' => readArray(dis, jvmObjectTracker) + case 'l' => readList(dis, jvmObjectTracker) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => jvmObjectTracker(JVMObjectId(readString(dis))) case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { - val obj = (sqlSerDe._1)(dis, dataType) + val obj = sqlReadObject(dis, dataType) if (obj == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { @@ -181,28 +189,28 @@ private[spark] object SerDe { } // All elements of an array must be of the same type - def readArray(dis: DataInputStream): Array[_] = { + def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) case 'c' => readStringArr(dis) case 'd' => readDoubleArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x))) case 'r' => readBytesArr(dis) case 'a' => val len = readInt(dis) - (0 until len).map(_ => readArray(dis)).toArray + (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray case 'l' => val len = readInt(dis) - (0 until len).map(_ => readList(dis)).toArray + (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray case _ => - if (sqlSerDe 
== null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { val len = readInt(dis) (0 until len).map { _ => - val obj = (sqlSerDe._1)(dis, arrType) + val obj = sqlReadObject(dis, arrType) if (obj == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { @@ -215,17 +223,19 @@ private[spark] object SerDe { // Each element of a list can be of different type. They are all represented // as Object on JVM side - def readList(dis: DataInputStream): Array[Object] = { + def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = { val len = readInt(dis) - (0 until len).map(_ => readObject(dis)).toArray + (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + def readMap( + in: DataInputStream, + jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { // Keys is an array of String - val keys = readArray(in).asInstanceOf[Array[Object]] - val values = readList(in) + val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]] + val values = readList(in, jvmObjectTracker) keys.zip(values).toMap.asJava } else { @@ -272,7 +282,11 @@ private[spark] object SerDe { } } - private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + private def writeKeyValue( + dos: DataOutputStream, + key: Object, + value: Object, + jvmObjectTracker: JVMObjectTracker): Unit = { if (key == null) { throw new IllegalArgumentException("Key in map can't be null.") } else if (!key.isInstanceOf[String]) { @@ -280,10 +294,10 @@ private[spark] object SerDe { } writeString(dos, key.asInstanceOf[String]) - writeObject(dos, value) + writeObject(dos, value, jvmObjectTracker) } - def writeObject(dos: DataOutputStream, obj: Object): Unit = { + def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = { if (obj == null) { writeType(dos, "void") } else { @@ -373,14 +387,14 @@ private[spark] object SerDe { case v: Array[Object] => writeType(dos, "list") writeInt(dos, v.length) - v.foreach(elem => writeObject(dos, elem)) + v.foreach(elem => writeObject(dos, elem, jvmObjectTracker)) // Handle Properties // This must be above the case java.util.Map below. 
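That ordering note is load-bearing: `java.util.Properties` extends `Hashtable`, so it is itself a `java.util.Map`, and a `Map` case listed first would swallow it. A standalone illustration:
{{{
val p = new java.util.Properties()
val tag = p match {
  case _: java.util.Properties => "jobj" // reached only because it is listed first
  case _: java.util.Map[_, _]  => "map"
}
assert(tag == "jobj")
}}}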
// (Properties implements Map and will be serialized as map otherwise) case v: java.util.Properties => writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) // Handle map case v: java.util.Map[_, _] => @@ -392,19 +406,21 @@ private[spark] object SerDe { val key = entry.getKey val value = entry.getValue - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + writeKeyValue( + dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker) } case v: scala.collection.Map[_, _] => writeType(dos, "map") writeInt(dos, v.size) - v.foreach { case (key, value) => - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + v.foreach { case (k1, v1) => + writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker) } case _ => - if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value) + if (!sqlWriteSucceeded) { writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) } } } @@ -447,9 +463,9 @@ private[spark] object SerDe { out.write(value) } - def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) - writeString(out, objId) + def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = { + val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value) + writeString(out, id) } def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e8d6d587b482..22d01c47e645 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer +import java.util.zip.Adler32 import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 + checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true) } setConf(SparkEnv.get.conf) @@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** Total number of blocks this broadcast variable contains. */ private val numBlocks: Int = writeBlocks(obj) + /** Whether to generate checksum for blocks or not. */ + private var checksumEnabled: Boolean = false + /** The checksum for all the blocks. */ + private var checksums: Array[Int] = _ + override protected def getValue() = { _value } + private def calcChecksum(block: ByteBuffer): Int = { + val adler = new Adler32() + if (block.hasArray) { + adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position) + } else { + val bytes = new Array[Byte](block.remaining()) + block.duplicate.get(bytes) + adler.update(bytes) + } + adler.getValue.toInt + } + /** * Divide the object into multiple blocks and put those blocks in the block manager. 
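A standalone version of the `calcChecksum` helper above, taking only the generic (non-array-backed) path; `duplicate()` keeps the caller's buffer position untouched:
{{{
import java.nio.ByteBuffer
import java.util.zip.Adler32

def checksum(block: ByteBuffer): Int = {
  val adler = new Adler32()
  val bytes = new Array[Byte](block.remaining())
  block.duplicate().get(bytes) // read without moving the original position
  adler.update(bytes)
  adler.getValue.toInt         // Adler-32 fits in 32 bits
}
}}}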
* @@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) + if (checksumEnabled) { + checksums = new Array[Int](blocks.length) + } blocks.zipWithIndex.foreach { case (block, i) => + if (checksumEnabled) { + checksums(i) = calcChecksum(block) + } val pieceId = BroadcastBlockId(id, "piece" + i) val bytes = new ChunkedByteBuffer(block.duplicate()) if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { @@ -135,6 +160,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => bm.getRemoteBytes(pieceId) match { case Some(b) => + if (checksumEnabled) { + val sum = calcChecksum(b.chunks(0)) + if (sum != checksums(pid)) { + throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" + + s" $sum != ${checksums(pid)}") + } + } // We found the block from remote executors/driver's BlockManager, so put the block // in this executor's BlockManager. if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { @@ -175,11 +207,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId).map(_.data.next()) match { - case Some(x) => - releaseLock(broadcastId) - x.asInstanceOf[T] - + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") + } case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b7113..a4de3d7eaf45 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac3..1b3e9521c9bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { @@ -373,7 +373,7 @@ class SparkHadoopUtil extends Logging { } } catch { case e: IOException => - logDebug("Failed to decode $token: $e", e) + logDebug(s"Failed to decode $token: $e", e) } buffer.toString } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5c052286099f..443f1f5b084c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -322,7 +322,7 @@ object SparkSubmit { } // Require all R files to be local - if (args.isR && !isYarnCluster) { + if (args.isR && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } @@ -330,9 +330,6 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -410,9 +407,9 @@ object SparkSubmit { printErrorAndExit("Distributing R packages with standalone cluster is not supported.") } - // TODO: Support SparkR with mesos cluster - if (args.isR && clusterManager == MESOS) { - printErrorAndExit("SparkR is not supported for Mesos cluster.") + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -488,12 +485,17 @@ object SparkSubmit { // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - if (deployMode == CLIENT) { + // Also add the main application jar and any added jars to classpath in case YARN client + // requires these jars. 
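The `Locale.US` argument added to `tokenToString` above (and to the master and worker `createDateFormat` later in this patch) pins formatting to a single locale; without it the JVM's default locale leaks into the output, as this small demonstration shows:
{{{
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

val d = new Date(0L)
val us = new SimpleDateFormat("EEE MMM dd yyyy", Locale.US).format(d)
val fr = new SimpleDateFormat("EEE MMM dd yyyy", Locale.FRANCE).format(d)
println(s"$us vs $fr") // e.g. "Thu Jan 01 1970" vs "jeu. janv. 01 1970"
}}}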
+ if (deployMode == CLIENT || isYarnCluster) { childMainClass = args.mainClass if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { if (args.childArgs != null) { childArgs ++= args.childArgs } } @@ -598,6 +600,9 @@ object SparkSubmit { if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") } else { childArgs += (args.primaryResource, args.mainClass) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff83646..6d8758a3d3b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + 0 + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + 0 + } + /** * Returns a list of applications available for the history server to show. 
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c81..2bd019c53aae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -97,6 +97,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) + private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) + private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") + private val HISTORY_UI_ADMIN_ACLS_GROUPS = conf.get("spark.history.ui.admin.acls.groups", "") + logInfo(s"History server ui acls " + (if (HISTORY_UI_ACLS_ENABLE) "enabled" else "disabled") + + "; users with admin permissions: " + HISTORY_UI_ADMIN_ACLS.toString + + "; groups with admin permissions: " + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) @@ -108,7 +115,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +127,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically.
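The new `pendingReplayTasksCount` gauge is maintained around the replay executor in `checkForLogs` below: queue the tasks, bump the counter by the batch size, then decrement in a `finally` so a failed replay cannot leave the gauge stuck. Distilled into a sketch (names are illustrative):
{{{
import java.util.concurrent.{Executors, Future}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable

val pending = new AtomicInteger(0) // surfaced as "event logs under process"
val pool = Executors.newFixedThreadPool(2)

def process(work: Seq[Runnable]): Unit = {
  val tasks = mutable.ListBuffer[Future[_]]()
  work.foreach(w => tasks += pool.submit(w))
  pending.addAndGet(tasks.size)
  tasks.foreach { t =>
    try t.get()
    catch { case e: Exception => println(s"replay failed: $e") }
    finally pending.decrementAndGet() // always restore the gauge
  }
}
}}}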
@@ -226,6 +235,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -244,13 +257,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) + ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) + val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") + ui.getSecurityManager.setAdminAcls(adminAcls) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) + val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + appListener.adminAclsGroups.getOrElse("") + ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) } else { @@ -329,26 +343,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +396,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") @@ -640,9 +671,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) false } - // For testing. private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) + /* true to check only for Active NNs status */ + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14..af1471763340 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -26,17 +26,35 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + // stripXSS is called first to remove suspicious characters used in XSS attacks val requestedIncomplete = - Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean + Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
<ul class="unstyled"> {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }} </ul>
+ {
+ if (eventLogsUnderProcessCount > 0) {
+ <p>There are {eventLogsUnderProcessCount} event log(s) currently being
+ processed which may result in additional applications getting listed on this page.
+ Refresh the page to view updates.</p>
+ }
+ }
+
+ {
+ if (lastUpdatedTime > 0) {
+ <p>Last updated: {lastUpdatedTime}</p>
+ }
+ }
+
+ { if (allAppsSize > 0) { ++ @@ -46,6 +64,8 @@ } else if (requestedIncomplete) { <h4>No incomplete applications found!</h4>
+ } else if (eventLogsUnderProcessCount > 0) {
+ <h4>No completed applications found!</h4>
} else { <h4>No completed applications found!</h4>
++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56..b02992af7b04 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } @@ -261,7 +269,7 @@ object HistoryServer extends Logging { Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() - val securityManager = new SecurityManager(conf) + val securityManager = createSecurityManager(conf) val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) @@ -281,6 +289,24 @@ object HistoryServer extends Logging { while(true) { Thread.sleep(Int.MaxValue) } } + /** + * Create a security manager. + * This turns off security in the SecurityManager, so that the History Server can start + * in a Spark cluster where security is enabled. + * @param config configuration for the SecurityManager constructor + * @return the security manager for use in constructing the History Server. + */ + private[history] def createSecurityManager(config: SparkConf): SecurityManager = { + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + + new SecurityManager(config) + } + def initSecurity() { // If we are accessing HDFS and it has security enabled (Kerberos), we have to login // from a keytab file so that we can access HDFS beyond the kerberos ticket expiration. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 2eddb5ff5447..080ba12c2f0d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Command-line parser for the master. + * Command-line parser for the [[HistoryServer]]. 
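The new `createSecurityManager` helper above, in miniature (it is `private[history]`, so this sketch compiles only in that package; values are illustrative): the generic ACL switches are cleared so that only `spark.history.ui.acls.enable` governs access to the rendered UIs:
{{{
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.acls.enable", "true") // would otherwise gate the whole server
  .set("spark.history.ui.acls.enable", "true")
val securityManager = HistoryServer.createSecurityManager(conf)
assert(conf.get("spark.acls.enable") == "false")    // cleared by the helper
assert(conf.get("spark.ui.acls.enable") == "false") // likewise
}}}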
*/ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c..4618e6117a4f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 18cff3125d6b..57778701ff9e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val appId = UIUtils.stripXSS(request.getParameter("appId")) val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 3fb860582cc1..d26e99cda31a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val id = Option(request.getParameter("id")) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val killFlag = + Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean + val id = Option(UIUtils.stripXSS(request.getParameter("id"))) if (id.isDefined && killFlag) { action(id.get) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e878c10183f6..58a181128eb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ 
-57,7 +57,8 @@ private[deploy] class DriverRunner( @volatile private[worker] var finalException: Option[Exception] = None // Timeout to wait for when trying to terminate a driver. - private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000 + private val DRIVER_TERMINATE_TIMEOUT_MS = + conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s") // Decoupled for testing def setClock(_clock: Clock): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a96..0940f3c55844 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -187,8 +187,7 @@ private[deploy] class Worker( webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() - val scheme = if (webUi.sslOptions.enabled) "https" else "http" - workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}" + workerWebUiUrl = s"http://$publicAddress:${webUi.boundPort}" registerWithMaster() metricsSystem.registerSource(workerSource) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 465c214362b2..63c0d7d418b6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -35,13 +35,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 + // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => @@ -57,13 +60,16 @@ private[ui] class 
LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 7eec4ae64f29..92a27902c669 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { new SecurityManager(executorConf), clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ - Seq[(String, String)](("spark.app.id", appId)) + val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. @@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, isLocal = false) + driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9501dd9cd8e9..3346f6dd1f97 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -84,6 +84,16 @@ private[spark] class Executor( // Start worker thread pool private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") private val executorSource = new ExecutorSource(threadPool, executorId) + // Pool used for threads that supervise task killing / cancellation + private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") + // For tasks which are in the process of being killed, this map holds the most recently created + // TaskReaper. 
All accesses to this map should be synchronized on the map itself (this isn't + // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding + // the integrity of the map's internal state). The purpose of this map is to prevent the creation + // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to + // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise + // create. The map key is a task id. + private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -93,6 +103,9 @@ private[spark] class Executor( // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) + // Whether to monitor killed / interrupted tasks + private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false) + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -148,9 +161,27 @@ private[spark] class Executor( } def killTask(taskId: Long, interruptThread: Boolean): Unit = { - val tr = runningTasks.get(taskId) - if (tr != null) { - tr.kill(interruptThread) + val taskRunner = runningTasks.get(taskId) + if (taskRunner != null) { + if (taskReaperEnabled) { + val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized { + val shouldCreateReaper = taskReaperForTask.get(taskId) match { + case None => true + case Some(existingReaper) => interruptThread && !existingReaper.interruptThread + } + if (shouldCreateReaper) { + val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + taskReaperForTask(taskId) = taskReaper + Some(taskReaper) + } else { + None + } + } + // Execute the TaskReaper from outside of the synchronized block. + maybeNewTaskReaper.foreach(taskReaperPool.execute) + } else { + taskRunner.kill(interruptThread = interruptThread) + } } } @@ -161,12 +192,7 @@ private[spark] class Executor( * @param interruptThread whether to interrupt the task thread */ def killAllTasks(interruptThread: Boolean) : Unit = { - // kill all the running tasks - for (taskRunner <- runningTasks.values().asScala) { - if (taskRunner != null) { - taskRunner.kill(interruptThread) - } - } + runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) } def stop(): Unit = { @@ -192,13 +218,21 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + val threadName = s"Executor task launch worker for task $taskId" + /** Whether this task has been killed. */ @volatile private var killed = false + @volatile private var threadId: Long = -1 + + def getThreadId: Long = threadId + /** Whether this task has been finished. */ @GuardedBy("TaskRunner.this") private var finished = false + def isFinished: Boolean = synchronized { finished } + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ @@ -229,9 +263,15 @@ private[spark] class Executor( // ClosedByInterruptException during execBackend.statusUpdate which causes // Executor to crash Thread.interrupted() + // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there + // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) + // is followed by cancel(interrupt=True). 
Thus we use notifyAll() to avoid a lost wakeup: + notifyAll() } override def run(): Unit = { + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() @@ -431,6 +471,117 @@ private[spark] class Executor( } } + /** + * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally + * sending a Thread.interrupt(), and monitoring the task until it finishes. + * + * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks + * may not be interruptible or may not respond to their "killed" flags being set. If a significant + * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but + * remain running then this can lead to a situation where new jobs and tasks are starved of + * resources that are being used by these zombie tasks. + * + * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie + * tasks. For backwards-compatibility / backportability this component is disabled by default + * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`. + * + * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically + * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers + * in case kill is called twice with different values for the `interrupt` parameter. + * + * Once created, a TaskReaper will run until its supervised task has finished running. If the + * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if + * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely + * if the supervised task never exits. + */ + private class TaskReaper( + taskRunner: TaskRunner, + val interruptThread: Boolean) + extends Runnable { + + private[this] val taskId: Long = taskRunner.taskId + + private[this] val killPollingIntervalMs: Long = + conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s") + + private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1") + + private[this] val takeThreadDump: Boolean = + conf.getBoolean("spark.task.reaper.threadDump", true) + + override def run(): Unit = { + val startTimeMs = System.currentTimeMillis() + def elapsedTimeMs = System.currentTimeMillis() - startTimeMs + def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs + try { + // Only attempt to kill the task once. If interruptThread = false then a second kill + // attempt would be a no-op and if interruptThread = true then it may not be safe or + // effective to interrupt multiple times: + taskRunner.kill(interruptThread = interruptThread) + // Monitor the killed task until it exits. The synchronization logic here is complicated + // because we don't want to synchronize on the taskRunner while possibly taking a thread + // dump, but we also need to be careful to avoid races between checking whether the task + // has finished and wait()ing for it to finish. + var finished: Boolean = false + while (!finished && !timeoutExceeded()) { + taskRunner.synchronized { + // We need to synchronize on the TaskRunner while checking whether the task has + // finished in order to avoid a race where the task is marked as finished right after + // we check and before we call wait().
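The reaper is opt-in and is tuned entirely through the four properties read above. A minimal sketch of enabling it on a `SparkConf` (the values are illustrative, not recommendations):

```scala
import org.apache.spark.SparkConf

// Sketch: enable and tune the TaskReaper via the properties read above.
val conf = new SparkConf()
  .set("spark.task.reaper.enabled", "true")         // default: false
  .set("spark.task.reaper.pollingInterval", "10s")  // how often to re-check a killed task
  .set("spark.task.reaper.killTimeout", "120s")     // default "-1": never kill the JVM
  .set("spark.task.reaper.threadDump", "true")      // log a thread dump on each poll
```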
+ if (taskRunner.isFinished) { + finished = true + } else { + taskRunner.wait(killPollingIntervalMs) + } + } + if (taskRunner.isFinished) { + finished = true + } else { + logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms") + if (takeThreadDump) { + try { + Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread => + if (thread.threadName == taskRunner.threadName) { + logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}") + } + } + } catch { + case NonFatal(e) => + logWarning("Exception thrown while obtaining thread dump: ", e) + } + } + } + } + + if (!taskRunner.isFinished && timeoutExceeded()) { + if (isLocal) { + logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " + + "not killing JVM because we are running in local mode.") + } else { + // In non-local-mode, the exception thrown here will bubble up to the uncaught exception + // handler and cause the executor JVM to exit. + throw new SparkException( + s"Killing executor JVM because killed task $taskId could not be stopped within " + + s"$killTimeoutMs ms.") + } + } + } finally { + // Clean up entries in the taskReaperForTask map. + taskReaperForTask.synchronized { + taskReaperForTask.get(taskId).foreach { taskReaperInMap => + if (taskReaperInMap eq this) { + taskReaperForTask.remove(taskId) + } else { + // This must have been a TaskReaper where interruptThread == false where a subsequent + // killTask() call for the same task had interruptThread == true and overwrote the + // map entry. + } + } + } + } + } + } + /** * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f66510b6f977..59404e08895a 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,6 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} +import org.apache.spark.internal.config +import org.apache.spark.SparkContext + /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -40,9 +43,14 @@ private[spark] abstract class StreamFileInputFormat[T] * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API * which is set through setMaxSplitSize */ - def setMinPartitions(context: JobContext, minPartitions: Int) { - val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum - val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong + def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { + val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultParallelism = sc.defaultParallelism + val files = listStatus(context).asScala + val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum + val bytesPerCore = totalBytes / defaultParallelism + val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) super.setMaxSplitSize(maxSplitSize) } diff --git 
a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 497ca92c7bc6..f4844dee62ef 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -198,12 +198,26 @@ package object config { .createWithDefault(0) private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port") - .doc("Port to use for the block managed on the driver.") + .doc("Port to use for the block manager on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupt files and contents that have been read will still be returned.") + "encountering corrupted or non-existing files and contents that have been read will still " + + "be returned.") .booleanConf .createWithDefault(false) + + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) + + private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes") + .doc("The estimated cost to open a file, measured by the number of bytes that could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over-estimate; then the partitions with small files will be faster than partitions with" + + " bigger files.") + .longConf + .createWithDefault(4 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 000000000000..afd2250c93a8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. Three notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1.
Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Sets up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or is killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task-related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or is killed) before it can call abort.
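Read together, the contract above amounts to the following choreography. A sketch only, assuming a concrete `committer: FileCommitProtocol` plus Hadoop `jobCtx: JobContext` and `taskCtx: TaskAttemptContext` values in scope (the partition directory and extension are illustrative):

```scala
// Sketch of the documented call sequence; jobCtx, taskCtx and committer assumed.
committer.setupJob(jobCtx)                                 // 1. driver, before anything else
committer.setupTask(taskCtx)                               // 2. per task, on executors
val file = committer.newTaskTempFile(taskCtx, Some("dt=2016-11-01"), ".snappy.parquet")
val msg  = committer.commitTask(taskCtx)                   //    or abortTask(taskCtx) on failure
committer.commitJob(jobCtx, Seq(msg))                      // 3. driver, after all tasks succeed
```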
+ */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: String, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 000000000000..c99b75e52325 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.util.{Date, UUID} + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.SparkHadoopWriter +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file.
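The reflection in `instantiate` above means a custom protocol only has to expose one of the two documented constructor shapes. A hypothetical subclass relying on the 2-argument fallback:

```scala
// Hypothetical: MyProtocol has no (jobId, path, isAppend) constructor, so
// instantiate() falls back to the (jobId, path) one.
class MyProtocol(jobId: String, path: String)
  extends HadoopMapReduceCommitProtocol(jobId, path)

val committer = FileCommitProtocol.instantiate(
  classOf[MyProtocol].getName, jobId = "0", outputPath = "/tmp/out", isAppend = false)
```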
+ */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + context.getOutputFormatClass.newInstance().getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + val filename = getFilename(taskContext, ext) + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
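For concreteness, the naming scheme described in the comment above, evaluated with illustrative values (the job id and extension are made up):

```scala
// Worked example of the f-interpolated name built just below.
val jobId = "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"
val split = 3
val ext = ".gz.parquet"
f"part-$split%05d-$jobId$ext"
// => part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.gz.parquet
// %05d pads rather than truncates, so split = 123456 still yields part-123456-...
```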
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + new TaskCommitMessage(addedAbsPathFiles.toMap) + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef75..2e991ce394c4 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. 
* Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -173,7 +172,7 @@ private final object SnappyCompressionCodec { } /** - * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. */ diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 3f7cfd9d2c11..99ec78633ab7 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -85,6 +85,17 @@ object HiveCatalogMetrics extends Source { */ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + /** + * Tracks the total number of Hive client calls (e.g. to lookup a table). + */ + val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls")) + + /** + * Tracks the total number of Spark jobs launched for parallel file listing. + */ + val METRIC_PARALLEL_LISTING_JOB_COUNT = metricRegistry.counter( + MetricRegistry.name("parallelListingJobCount")) + /** * Resets the values of all metrics to zero. This is useful in tests. 
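The two new counters follow the same pattern as the existing ones; a sketch of how a caller and a test assertion might use the static helpers (usage illustrative, not from the patch):

```scala
import org.apache.spark.metrics.source.HiveCatalogMetrics

// Illustrative: callers go through the static helpers to stay decoupled from
// the codahale classes; tests can read the counters back directly.
HiveCatalogMetrics.incrementHiveClientCalls(1)
HiveCatalogMetrics.incrementParallelListingJobCount(1)
assert(HiveCatalogMetrics.METRIC_HIVE_CLIENT_CALLS.getCount() == 1)
```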
*/ @@ -92,10 +103,14 @@ object HiveCatalogMetrics extends Source { METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount()) METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount()) METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount()) + METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount()) + METRIC_PARALLEL_LISTING_JOB_COUNT.dec(METRIC_PARALLEL_LISTING_JOB_COUNT.getCount()) } // clients can use these to avoid classloader issues with the codahale classes def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n) def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n) def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n) + def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n) + def incrementParallelListingJobCount(n: Int): Unit = METRIC_PARALLEL_LISTING_JOB_COUNT.inc(n) } diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 8f83668d7902..b3f8bfe8b1d4 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -46,5 +46,5 @@ trait BlockDataManager { /** * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. */ - def releaseLock(blockId: BlockId): Unit + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6..8f579c5a3033 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 41832e835474..50d977a92da5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat private[spark] class BinaryFileRDD[T]( - sc: SparkContext, + @transient private val sc: SparkContext, inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], @@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T]( case _ => } val jobContext = new JobContextImpl(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) + inputFormat.setMinPartitions(sc, jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f0..a091f06b4ed7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ 
private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c5..14331dfd0c98 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -152,13 +152,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the right except for the last which is closed + * to the right except for the last which is closed. * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * e.g. {@code 1<=x<10, 10<=x<20, 20<=x<=50} * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de09..b56ebf4df06e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate.
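The relocated @note just below this parameter list points users at the `SparkContext` entry point instead of the constructor. A minimal sketch of that route (input path and formats are illustrative; `sc` is an existing `SparkContext`):

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}

// Illustrative: prefer sc.hadoopRDD(...) over instantiating HadoopRDD directly.
val jobConf = new JobConf(sc.hadoopConfiguration)
FileInputFormat.setInputPaths(jobConf, "/data/input")
val lines = sc.hadoopRDD(jobConf, classOf[TextInputFormat],
  classOf[LongWritable], classOf[Text], minPartitions = 4)
```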
+ * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.hadoopRDD()` */ @DeveloperApi class HadoopRDD[K, V]( @@ -210,12 +210,12 @@ class HadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] + private val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - val jobConf = getJobConf() + private val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { @@ -225,7 +225,7 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None @@ -235,28 +235,39 @@ class HadoopRDD[K, V]( // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + private var reader: RecordReader[K, V] = null + private val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + reader = + try { + inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. 
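The lenient record-reader construction above is gated by the `spark.files.ignoreCorruptFiles` flag defined earlier in this patch (default false). A sketch of opting in:

```scala
import org.apache.spark.SparkConf

// Sketch: with the flag set, a corrupt split is logged and skipped
// (finished = true) instead of failing the task.
val conf = new SparkConf().set("spark.files.ignoreCorruptFiles", "true")
```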
context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() + private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true } if (!finished) { inputMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala index f40d4c8e0a4d..960c91a154db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -22,6 +22,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * This holds file names of the current Spark task. This is used in HadoopRDD, * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. + * + * The returned value should never be null; it should be an empty string when the name is unknown. */ private[spark] object InputFileNameHolder { /** @@ -32,9 +34,15 @@ private[spark] object InputFileNameHolder { override protected def initialValue(): UTF8String = UTF8String.fromString("") } + /** + * Returns the current file name, or an empty string if it is unknown. + */ def getInputFileName(): UTF8String = inputFileName.get() - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + private[spark] def setInputFileName(file: String) = { + require(file != null, "The input file name cannot be null") + inputFileName.set(UTF8String.fromString(file)) + } private[spark] def unsetInputFileName(): Unit = inputFileName.remove() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0970b9807167..aab46b8954bf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -41,7 +41,10 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e * The RDD takes care of closing the connection. * @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb65887..6168d979032a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. + * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.newAPIHadoopRDD()` */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } @@ -132,12 +132,12 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] + private val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf + private val conf = getConf - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { @@ -147,46 +147,62 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + private val getBytesReadCallback: Option[() => Long] = + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. 
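Stepping back to the `JdbcRDD` contract spelled out above: the two `?` placeholders receive each partition's bounds. A self-contained sketch, with the JDBC URL and schema purely illustrative:

```scala
import java.sql.DriverManager
import org.apache.spark.rdd.JdbcRDD

// Illustrative: the two '?' placeholders are bound to each partition's id range.
val books = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:h2:mem:testdb"),
  "select title, author from books where ? <= id and id <= ?",
  lowerBound = 1, upperBound = 1000, numPartitions = 10,
  rs => (rs.getString(1), rs.getString(2)))
```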
// If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - val format = inputFormatClass.newInstance + private val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + private var finished = false + private var reader = + try { + val _reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + private var havePair = false + private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { try { finished = !reader.nextKeyValue } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true } if (finished) { // Close and release the reader here; close() will also be called when the task diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad74..dc123e23b781 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.nio.ByteBuffer import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{Date, HashMap => JHashMap, Locale} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ @@ -59,8 +59,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -68,6 +68,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -363,7 +366,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -398,9 +401,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is + * greater than `p`) would trigger sparse representation of registers, which may reduce the + * memory consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -490,11 +493,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -514,11 +517,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. 
If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -635,9 +638,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { groupByKey(defaultPartitioner(self)) @@ -907,20 +910,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return an RDD with the pairs from `this` whose keys are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope { subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W: ClassTag]( other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = self.withScope { subtractByKey(other, new HashPartitioner(numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope { new SubtractedRDD[K, V, W](self, other, p) } @@ -1016,7 +1023,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1070,7 +1077,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. 
a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1079,7 +1086,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) val jobtrackerID = formatter.format(new Date()) val stageId = self.id val jobConfiguration = job.getConfiguration diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee..ce75a16031a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b..6a89ea878646 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb..199a3770ac4a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. + * Spark paper + * for more details on RDD internals. 
*/ abstract class RDD[T: ClassTag]( @transient private var _sc: SparkContext, @@ -195,10 +195,14 @@ abstract class RDD[T: ClassTag]( } } - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): this.type = persist() /** @@ -419,7 +423,8 @@ abstract class RDD[T: ClassTag]( * * This results in a narrow dependency, e.g. if you go from 1000 partitions * to 100 partitions, there will not be a shuffle, instead each of the 100 - * new partitions will claim 10 of the current partitions. + * new partitions will claim 10 of the current partitions. If a larger number + * of partitions is requested, it will stay at the current number of partitions. * * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, * this may result in your computation taking place on fewer nodes than @@ -428,7 +433,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -469,8 +474,12 @@ abstract class RDD[T: ClassTag]( * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -534,13 +543,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -615,7 +624,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. 
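The sentence added to the `coalesce` doc above pins down the no-shuffle behavior; a quick illustration with made-up partition counts:

```scala
// Illustrative: coalesce only narrows without a shuffle; growing the count
// requires shuffle = true, as the note above explains.
val rdd = sc.parallelize(1 to 1000, numSlices = 1000)
rdd.coalesce(100).getNumPartitions                    // 100, no shuffle
rdd.coalesce(2000).getNumPartitions                   // stays at 1000
rdd.coalesce(2000, shuffle = true).getNumPartitions   // 2000
```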
*/ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -627,7 +636,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -643,7 +652,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -671,9 +680,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) @@ -684,9 +693,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K]( f: T => K, @@ -699,9 +708,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = withScope { @@ -747,8 +756,10 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. 
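The recommendation above to prefer `PairRDDFunctions.reduceByKey` over `groupBy` for aggregations, made concrete (a sketch; assumes a `SparkContext` named `sc`):

{{{
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

// Expensive: ships every value for a key across the network before aggregating.
val viaGroupBy = pairs.groupBy(_._1).mapValues(_.map(_._2).sum)

// Much cheaper: partial sums are combined map-side before the shuffle.
val viaReduceByKey = pairs.reduceByKey(_ + _)
}}}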
* An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2) {f(e)} + * {{{ + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2) {f(e)} + * }}} * @param separateWorkingDir Use separate working directories for each task. * @param bufferSize Buffer size for the stdin writer for the piped process. * @param encoding Char encoding used for interacting (via stdin, stdout and stderr) with @@ -788,14 +799,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { @@ -906,7 +929,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -919,7 +942,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1167,10 +1190,15 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which - * returns an RDD[T, Long] instead of a map. + * To handle very large results, consider using + * + * {{{ + * rdd.map(x => (x, 1L)).reduceByKey(_ + _) + * }}} + * + * which returns an RDD[(T, Long)] instead of a map.
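For example (a sketch; assumes a `SparkContext` named `sc`), the two approaches from the note above side by side:

{{{
val words = sc.parallelize(Seq("a", "b", "a", "c", "a"))

// Fine when the number of distinct values is small: a local Map on the driver.
val local: scala.collection.Map[String, Long] = words.countByValue()

// Scales to many distinct values: the counts stay distributed as an RDD.
val distributed = words.map(x => (x, 1L)).reduceByKey(_ + _)
}}}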
*/ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope { map(value => (value, null)).countByKey() } @@ -1208,9 +1236,9 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp` greater + * than `p` would trigger sparse representation of registers, which may reduce the memory + * consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -1257,7 +1285,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1271,7 +1299,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1290,10 +1318,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1355,7 +1383,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory.
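The difference between the two indexing methods noted above, sketched with a two-partition RDD (assumes a `SparkContext` named `sc`; the unique ids follow the k, n+k, 2*n+k rule):

{{{
val r = sc.parallelize(Seq("x", "y", "z"), 2)  // partitions: ["x"] and ["y", "z"]

r.zipWithIndex().collect()
// Array((x,0), (y,1), (z,2)) -- contiguous indices, but triggers an extra job

r.zipWithUniqueId().collect()
// Array((x,0), (y,1), (z,3)) -- no job is triggered, and gaps are possible
}}}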
* * @param num k, the number of top elements to return @@ -1378,7 +1406,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1423,7 +1451,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) @@ -1583,14 +1611,15 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD is checkpointed and materialized, either reliably or locally. */ - def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) + def isCheckpointed: Boolean = isCheckpointedAndMaterialized /** * Return whether this RDD is checkpointed and materialized, either reliably or locally. * This is introduced as an alias for `isCheckpointed` to clarify the semantics of the * return value. Exposed for testing. */ - private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed + private[spark] def isCheckpointedAndMaterialized: Boolean = + checkpointData.exists(_.isCheckpointed) /** * Return whether this RDD is marked for local checkpointing. @@ -1719,7 +1748,7 @@ abstract class RDD[T: ClassTag]( /** * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * to the original parent RDDs are removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6be..6c552d4d1251 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,7 +23,8 @@ import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> checkpointing in progress --> checkpointed ]. + * + * [ Initialized --{@literal >} checkpointing in progress --{@literal >} checkpointed ] */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value @@ -32,7 +33,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. It manages the process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD.
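From the user-facing side, the state transitions above look like this (a sketch; assumes a `SparkContext` named `sc` and a writable, hypothetical checkpoint directory):

{{{
sc.setCheckpointDir("/tmp/spark-checkpoints")  // hypothetical path

val data = sc.parallelize(1 to 100).map(_ * 2)
data.checkpoint()  // Initialized: only marks the RDD for checkpointing
data.count()       // an action drives it through "in progress" to "checkpointed"
assert(data.isCheckpointed)
}}}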
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067..7f399ecf81a0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c7..86a332790fb0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. */ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index ad1fddbde7b0..60e383afadf1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport} +import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c6..8425b211d6ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. 
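The implicit conversion to `SequenceFileRDDFunctions` mentioned above kicks in whenever both the key and value types are convertible to `Writable`. A sketch (assumes a `SparkContext` named `sc`; the output path is hypothetical):

{{{
// Int and String convert implicitly to IntWritable and Text, so
// saveAsSequenceFile becomes available on this pair RDD.
val kv = sc.parallelize(Seq((1, "one"), (2, "two")))
kv.saveAsSequenceFile("/tmp/seqfile-demo")
}}}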
* diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala index d8a80aa5aeb1..e00bc22aba44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala +++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala @@ -35,14 +35,14 @@ trait PartitionCoalescer { * @param maxPartitions the maximum number of partitions to have after coalescing * @param parent the parent RDD whose partitions to coalesce * @return an array of [[PartitionGroup]]s, where each element is itself an array of - * [[Partition]]s and represents a partition after coalescing is performed. + * `Partition`s and represents a partition after coalescing is performed. */ def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] } /** * ::DeveloperApi:: - * A group of [[Partition]]s + * A group of `Partition`s * @param prefLoc preferred location for the partition group */ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index f527ec86ab7b..117f51c5b8f2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 579122868afc..530743c03640 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -146,7 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * @param uri URI with location of the file. */ def openChannel(uri: String): ReadableByteChannel - } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e51649a1ecce..e56943da1303 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,11 +407,9 @@ private[netty] class NettyRpcEnv( } } - } private[netty] object NettyRpcEnv extends Logging { - /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. 
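One way to read the `PartitionCoalescer` developer API above is through a toy implementation. This is a hypothetical strategy, not part of this patch: pack parent partitions round-robin into at most `maxPartitions` groups (marked `Serializable` since the coalescer travels with the RDD):

{{{
import org.apache.spark.rdd.{PartitionCoalescer, PartitionGroup, RDD}

class RoundRobinCoalescer extends PartitionCoalescer with Serializable {
  override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
    val n = math.min(maxPartitions, parent.partitions.length)
    val groups = Array.fill(n)(new PartitionGroup())
    parent.partitions.zipWithIndex.foreach { case (p, i) =>
      groups(i % n).partitions += p  // PartitionGroup.partitions is a mutable buffer
    }
    groups
  }
}
}}}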
E.g., diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 6c090ada5ae9..a7b7f58376f6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -56,7 +56,7 @@ private[netty] case class RpcOutboxMessage( content: ByteBuffer, _onFailure: (Throwable) => Unit, _onSuccess: (TransportClient, ByteBuffer) => Unit) - extends OutboxMessage with RpcResponseCallback { + extends OutboxMessage with RpcResponseCallback with Logging { private var client: TransportClient = _ private var requestId: Long = _ @@ -67,8 +67,11 @@ private[netty] case class RpcOutboxMessage( } def onTimeout(): Unit = { - require(client != null, "TransportClient has not yet been set.") - client.removeRpcRequest(requestId) + if (client != null) { + client.removeRpcRequest(requestId) + } else { + logError("Ask timeout before connecting successfully") + } } override def onFailure(e: Throwable): Unit = { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index 99f20da2d66a..430dcc50ba71 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an `RpcEndpoint` exists. * * This is used when setting up a remote endpoint reference. */ @@ -35,6 +35,6 @@ private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher private[netty] object RpcEndpointVerifier { val NAME = "endpoint-verifier" - /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */ case class CheckExistence(name: String) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afe..0a5fe5a1d3ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. 
- * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f2517401cb76..01a95c06fc69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1660,7 +1660,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } - dagScheduler.sc.stop() + dagScheduler.sc.stopInNewThread() } override def onStop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index a6b032cc0084..66ab9a52b778 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -153,7 +153,7 @@ object InputFormatInfo { a) For each host, count number of splits hosted on that host. b) Decrement the currently allocated containers on that host. - c) Compute rack info for each host and update rack -> count map based on (b). + c) Compute rack info for each host and update rack to count map based on (b). d) Allocate nodes based on (c) e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node (even if data locality on that is very high) : this is to prevent fragility of job if a diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 0bd5a6bc59a9..08e05ae0c095 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -22,6 +22,7 @@ import java.io.{InputStream, IOException} import scala.io.Source import com.fasterxml.jackson.core.JsonParseException +import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging @@ -87,6 +88,12 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. // It's safe since no place uses them. 
logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") + case e: UnrecognizedPropertyException if e.getMessage != null && e.getMessage.startsWith( + "Unrecognized field \"queryStatus\" " + + "(class org.apache.spark.sql.streaming.StreamingQueryListener$") => + // Ignore events generated by Structured Streaming in Spark 2.0.2 + // It's safe since no place uses them. + logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 1e7c63af2e79..d19353f2a993 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * * The parameters below are optional: * @param jobId id of the job this task belongs to diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 66d6790e168f..31011de85bf7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -42,7 +42,7 @@ import org.apache.spark.shuffle.ShuffleWriter * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. * * The parameters below are optional: diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e..112b08f2c03a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,7 @@ import org.apache.spark.util._ * @param stageId id of the stage this task belongs to * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. 
* * The parameters below are optional: diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1c7c81c488c3..45c742cbff5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.SerializableBuffer /** * Description of a task that gets passed onto executors to be executed, usually created by - * [[TaskSetManager.resourceOffer]]. + * `TaskSetManager.resourceOffer`. */ private[spark] class TaskDescription( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 3e3f1ad031e6..b03cfe4f0dc4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -93,10 +93,12 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Number of tasks running on each executor - private val executorIdToTaskCount = new HashMap[String, Int] + // IDs of the tasks running on each executor + private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] - def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap + def runningTasksByExecutors: Map[String, Int] = synchronized { + executorIdToRunningTaskIds.toMap.mapValues(_.size) + } // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -264,7 +266,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId - executorIdToTaskCount(execId) += 1 + executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) launchedTask = true @@ -294,11 +296,11 @@ private[spark] class TaskSchedulerImpl( if (!hostToExecutors.contains(o.host)) { hostToExecutors(o.host) = new HashSet[String]() } - if (!executorIdToTaskCount.contains(o.executorId)) { + if (!executorIdToRunningTaskIds.contains(o.executorId)) { hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host - executorIdToTaskCount(o.executorId) = 0 + executorIdToRunningTaskIds(o.executorId) = HashSet[Long]() newExecAvail = true } for (rack <- getRackForHost(o.host)) { @@ -349,38 +351,34 @@ private[spark] class TaskSchedulerImpl( var reason: Option[ExecutorLossReason] = None synchronized { try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - - if (executorIdToTaskCount.contains(execId)) { - reason = Some( - SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) - removeExecutor(execId, reason.get) - failedExecutor = Some(execId) - } - } taskIdToTaskSetManager.get(tid) match { case Some(taskSet) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid).foreach { execId => - if (executorIdToTaskCount.contains(execId)) { - executorIdToTaskCount(execId) -= 1 - } + if (state == TaskState.LOST) { + // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, + // where each 
executor corresponds to a single task, so mark the executor as failed. + val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException( + "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)")) + if (executorIdToRunningTaskIds.contains(execId)) { + reason = Some( + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) + failedExecutor = Some(execId) } } - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + if (TaskState.isFinished(state)) { + cleanupTaskState(tid) taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + if (state == TaskState.FINISHED) { + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") + "likely the result of receiving duplicate task finished status updates) or its " + + "executor has been marked as failed.") .format(state, tid)) } } catch { @@ -491,7 +489,7 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (executorIdToTaskCount.contains(executorId)) { + if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) @@ -533,13 +531,31 @@ private[spark] class TaskSchedulerImpl( logError(s"Lost executor $executorId on $hostPort: $reason") } + /** + * Cleans up the TaskScheduler's state for tracking the given task. + */ + private def cleanupTaskState(tid: Long): Unit = { + taskIdToTaskSetManager.remove(tid) + taskIdToExecutorId.remove(tid).foreach { executorId => + executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) } + } + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - executorIdToTaskCount -= executorId + // The tasks on the lost executor may not send any more status updates (because the executor + // has been lost), so they should be cleaned up here. + executorIdToRunningTaskIds.remove(executorId).foreach { taskIds => + logDebug("Cleaning up TaskScheduler state for tasks " + + s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId") + // We do not notify the TaskSetManager of the task failures because that will + // happen below in the rootPool.executorLost() call. 
+ taskIds.foreach(cleanupTaskState) + } val host = executorIdToHost(executorId) val execs = hostToExecutors.getOrElse(host, new HashSet) @@ -577,11 +593,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - executorIdToTaskCount.contains(execId) + executorIdToRunningTaskIds.contains(execId) } def isExecutorBusy(execId: String): Boolean = synchronized { - executorIdToTaskCount.getOrElse(execId, -1) > 0 + executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty) } // By default, rack is unknown diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d151..0a4f19d76073 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { - case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage + + case class SparkAppConfig( + sparkProperties: Seq[(String, String)], + ioEncryptionKey: Option[Array[Byte]]) + extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 10d55c87fb8d..6fa239e4fee6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. 
private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -206,8 +210,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) - case RetrieveSparkProps => - context.reply(sparkProperties) + case RetrieveSparkAppConfig => + val reply = SparkAppConfig(sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey()) + context.reply(reply) } // Make fake resource offers on all executors @@ -388,6 +394,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -461,12 +468,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -498,6 +514,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -573,8 +590,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. 
val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cf..6f75a4791e94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,21 +126,21 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) } finally { // Ensure the application terminates, as we can no longer run jobs. 
- sc.stop() + sc.stopInNewThread() } } } @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8f15f50bee81..8e3436f13480 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -18,14 +18,13 @@ package org.apache.spark.security import java.io.{InputStream, OutputStream} import java.util.Properties +import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ -import org.apache.hadoop.io.Text import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -33,10 +32,6 @@ import org.apache.spark.internal.config._ * A util class for manipulating IO encryption and decryption streams. */ private[spark] object CryptoStreamUtils extends Logging { - /** - * Constants and variables for spark IO encryption - */ - val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN") // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 @@ -46,32 +41,30 @@ private[spark] object CryptoStreamUtils extends Logging { val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** - * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. */ def createCryptoOutputStream( os: OutputStream, - sparkConf: SparkConf): OutputStream = { + sparkConf: SparkConf, + key: Array[Byte]): OutputStream = { val properties = toCryptoConf(sparkConf) val iv = createInitializationVector(properties) os.write(iv) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoOutputStream(transformationStr, properties, os, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } /** - * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. + * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption. 
*/ def createCryptoInputStream( is: InputStream, - sparkConf: SparkConf): InputStream = { + sparkConf: SparkConf, + key: Array[Byte]): InputStream = { val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) is.read(iv, 0, iv.length) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) new CryptoInputStream(transformationStr, properties, is, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) @@ -91,6 +84,17 @@ private[spark] object CryptoStreamUtils extends Logging { props } + /** + * Creates a new encryption key. + */ + def createKey(conf: SparkConf): Array[Byte] = { + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) + keyGen.init(keyLen) + keyGen.generateKey().getEncoded() + } + /** * This method to generate an IV (Initialization Vector) using secure random. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b..f60dcfddfdc2 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe107..7eb2da1c2748 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -43,9 +43,10 @@ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, S import org.apache.spark.util.collection.CompactBuffer /** - * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * A Spark serializer that uses the + * Kryo serialization library. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0c..afe6cd86059f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. 
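The reworked `CryptoStreamUtils` methods above take the key from the caller instead of reading it from Hadoop credentials. A round-trip sketch (this is a `private[spark]` API, so the snippet is for illustration only):

{{{
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
val key = CryptoStreamUtils.createKey(conf)  // the helper added in this patch

val sink = new ByteArrayOutputStream()
val enc = CryptoStreamUtils.createCryptoOutputStream(sink, conf, key)
enc.write("secret".getBytes("UTF-8"))
enc.close()

val dec = CryptoStreamUtils.createCryptoInputStream(
  new ByteArrayInputStream(sink.toByteArray), conf, key)
}}}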
* They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2156d576f187..dd98ea265ce4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * Component which configures serialization, compression and encryption for various Spark * components, including automatic selection of which [[Serializer]] to use for shuffles. */ -private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { +private[spark] class SerializerManager( + defaultSerializer: Serializer, + conf: SparkConf, + encryptionKey: Option[Array[Byte]]) { + + def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) private[this] val kryoSerializer = new KryoSerializer(conf) @@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to enable IO encryption - private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED) - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -73,12 +75,17 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + def encryptionEnabled: Boolean = encryptionKey.isDefined + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } - def getSerializer(ct: ClassTag[_]): Serializer = { - if (canUseKryo(ct)) { + // SPARK-18617: the feature in SPARK-13990 can not be applied to Spark Streaming yet; at worst, + // a streaming job based on `Receiver` mode can not run properly on Spark 2.x. As a first step, + // it is reasonable to disable the `kryo auto pick` feature for streaming.
+ def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = { + if (autoPick && canUseKryo(ct)) { kryoSerializer } else { defaultSerializer @@ -124,15 +131,19 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: InputStream): InputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + def wrapForEncryption(s: InputStream): InputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } + .getOrElse(s) } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: OutputStream): OutputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + def wrapForEncryption(s: OutputStream): OutputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } + .getOrElse(s) } /** @@ -155,24 +166,32 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar outputStream: OutputStream, values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ - def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) + def dataSerialize[T: ClassTag]( + blockId: BlockId, + values: Iterator[T], + allowEncryption: Boolean = true): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], + allowEncryption = allowEncryption) } /** Serializes into a chunked byte buffer. 
*/ def dataSerializeWithExplicitClassTag( blockId: BlockId, values: Iterator[_], - classTag: ClassTag[_]): ChunkedByteBuffer = { + classTag: ClassTag[_], + allowEncryption: Boolean = true): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) - val ser = getSerializer(classTag).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(classTag, autoPick).newInstance() + val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream + ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -182,12 +201,15 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar */ def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream) + inputStream: InputStream, + maybeEncrypted: Boolean = true) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) - getSerializer(classTag) + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream + getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapStream(blockId, stream)) + .deserializeStream(wrapForCompression(blockId, decrypted)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 17bc04303fa8..d2df77f1c5dc 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1 import java.util.zip.ZipOutputStream import javax.servlet.ServletContext +import javax.servlet.http.HttpServletRequest import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. 
*/ @Path("/v1") -private[v1] class ApiRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJobs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs") def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs/{jobId: \\d+}") def getJob(@PathParam("appId") appId: String): OneJobResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneJobResource(ui) } } @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJob( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneJobResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneJobResource(ui) } } @Path("applications/{appId}/executors") def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ExecutorListResource(ui) } } @Path("applications/{appId}/allexecutors") def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllExecutorListResource(ui) } } @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ExecutorListResource(ui) } } @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getAllExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllExecutorListResource(ui) } } - @Path("applications/{appId}/stages") def getStages(@PathParam("appId") appId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllStagesResource(ui) } } @@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStages( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") def getStage(@PathParam("appId") appId: String): OneStageResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneStageResource(ui) } } @@ -144,14 +144,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStage( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneStageResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui 
=> new OneStageResource(ui) } } @Path("applications/{appId}/storage/rdd") def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllRDDResource(ui) } } @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdds( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllRDDResource(ui) } } @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneRDDResource(ui) } } @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdd( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneRDDResource(ui) } } @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // If this application has attempt IDs, its UI is keyed by appId/attemptId, so + // withSparkUI(appId, None) throws NotFoundException; retry with attempt id "1". + withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -234,19 +247,6 @@ private[spark] trait UIRoot { .status(Response.Status.SERVICE_UNAVAILABLE) .build() } - - /** - * Get the spark UI with the given appID, and apply a function - * to it. If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - getSparkUI(appKey) match { - case Some(ui) => - f(ui) - case None => throw new NotFoundException("no such app: " + appId) - } - } def securityManager: SecurityManager } @@ -263,13 +263,37 @@ private[v1] object UIRootFromServletContext { } } -private[v1] trait UIRootFromServletContext { +private[v1] trait ApiRequestContext { @Context - var servletContext: ServletContext = _ + protected var servletContext: ServletContext = _ + + @Context + protected var httpRequest: HttpServletRequest = _ def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) + + + /** + * Get the Spark UI with the given appId, and apply a function + * to it.
If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + uiRoot.getSparkUI(appKey) match { + case Some(ui) => + val user = httpRequest.getRemoteUser() + if (!ui.securityManager.checkUIViewPermissions(user)) { + throw new ForbiddenException(raw"""user "$user" is not authorized""") + } + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } } +private[v1] class ForbiddenException(msg: String) extends WebApplicationException( + Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + private[v1] class NotFoundException(msg: String) extends WebApplicationException( new NoSuchElementException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573d..76af33c1a18d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index b4a991eda35f..1cd37185d660 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response import javax.ws.rs.ext.Provider @Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { +private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { override def filter(req: ContainerRequestContext): Unit = { - val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + val user = httpRequest.getRemoteUser() if (!uiRoot.securityManager.checkUIViewPermissions(user)) { req.abortWith( Response .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user"is not authorized""") + .entity(raw"""user "$user" is not authorized""") .build() ) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd238222..d8d5e8958b23 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone 
+import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index dd8f5bacb9f6..c0e18e54f410 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging { /** * Release a lock on the given block. + * If the TaskContext is not propagated properly to all child threads of the task, we cannot + * read the TID from TaskContext, so the caller has to pass the TID value explicitly to + * release the lock. + * + * See SPARK-18406 for more discussion of this issue. */ - def unlock(blockId: BlockId): Unit = synchronized { - logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized { + val taskId = taskAttemptId.getOrElse(currentTaskAttemptId) + logTrace(s"Task $taskId releasing lock for $blockId") val info = get(blockId).getOrElse { throw new IllegalStateException(s"Block $blockId not found") } if (info.writerTask != BlockInfo.NO_WRITER) { info.writerTask = BlockInfo.NO_WRITER - writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + writeLocksByTask.removeBinding(taskId, blockId) } else { assert(info.readerCount > 0, s"Block $blockId is not locked for reading") info.readerCount -= 1 - val countsForTask = readLocksByTask(currentTaskAttemptId) + val countsForTask = readLocksByTask(taskId) val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 assert(newPinCountForTask >= 0, - s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + s"Task $taskId released lock on block $blockId more times than it acquired it") } notifyAll() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b83324e0f..6ff61bdd5b7b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -28,6 +28,8 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal +import com.google.common.io.ByteStreams + import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging @@ -38,6 +40,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{SerializerInstance,
SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -62,7 +65,7 @@ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - serializerManager: SerializerManager, + val serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -315,6 +318,9 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. + * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. */ override def putBlockData( blockId: BlockId, @@ -448,6 +454,7 @@ private[spark] class BlockManager( case Some(info) => val level = info.level logDebug(s"Level for block $blockId is $level") + val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId()) if (level.useMemory && memoryStore.contains(blockId)) { val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get @@ -455,7 +462,12 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream( blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag) } - val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + // We need to capture the current taskId in case the iterator completion is triggered + // from a different thread which does not have TaskContext set; see SPARK-18406 for + // discussion. + val ci = CompletionIterator[Any, Iterator[Any]](iter, { + releaseLock(blockId, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { val iterToReturn: Iterator[Any] = { @@ -472,7 +484,9 @@ private[spark] class BlockManager( serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, { + releaseLock(blockId, taskAttemptId) + }) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -648,10 +662,13 @@ private[spark] class BlockManager( } /** - * Release a lock on the given block. + * Release a lock on the given block, optionally with an explicit TID. + * The param `taskAttemptId` should be passed when the correct TID cannot be read from + * TaskContext, for example when the input iterator of a cached RDD is consumed to the end in a + * child thread. */ - def releaseLock(blockId: BlockId): Unit = { - blockInfoManager.unlock(blockId) + def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = { + blockInfoManager.unlock(blockId, taskAttemptId) } /** @@ -745,24 +762,54 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, + new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) } /** * Put a new block of serialized bytes to the block manager. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`.
+ * + * @param encrypt If true, asks the block manager to encrypt the data block before storing, + * when I/O encryption is enabled. This is required for blocks that have been + * read from unencrypted sources, since all the BlockManager read APIs + * automatically do decryption. * @return true if the block was stored or false if an error occurred. */ def putBytes[T: ClassTag]( blockId: BlockId, bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true): Boolean = { + tellMaster: Boolean = true, + encrypt: Boolean = false): Boolean = { require(bytes != null, "Bytes is null") - doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) + + val bytesToStore = + if (encrypt && securityManager.ioEncryptionKey.isDefined) { + try { + val data = bytes.toByteBuffer + val in = new ByteBufferInputStream(data) + val byteBufOut = new ByteBufferOutputStream(data.remaining()) + val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, + securityManager.ioEncryptionKey.get) + try { + ByteStreams.copy(in, out) + } finally { + in.close() + out.close() + } + new ChunkedByteBuffer(byteBufOut.toByteBuffer) + } finally { + bytes.dispose() + } + } else { + bytes + } + + doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster) } /** @@ -771,6 +818,9 @@ private[spark] class BlockManager( * * If the block already exists, this method will not overwrite it. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param keepReadLock if true, this method will hold the read lock when it returns (even if the * block already exists). If false, this method will hold no locks when it * returns. @@ -814,7 +864,15 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + val memoryMode = level.memoryMode + memoryStore.putBytes(blockId, size, memoryMode, () => { + if (memoryMode == MemoryMode.OFF_HEAP && + bytes.chunks.exists(buffer => !buffer.isDirect)) { + bytes.copy(Platform.allocateDirectBuffer) + } else { + bytes + } + }) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -1019,7 +1077,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.dispose() + bytesToReplicate.unmap() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 6bded9270050..d71acbb4cf77 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,7 +43,7 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerSlave /** - * Driver -> Executor message to trigger a thread dump. + * Driver to Executor message to trigger a thread dump. 
*/ case object TriggerThreadDump extends ToBlockManagerSlave diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bf087af16a5b..bb8a684b4c7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -89,17 +89,18 @@ class RandomBlockReplicationPolicy prioritizedPeers } + // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage - * [[http://math.stackexchange.com/questions/178690/ - * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * minimizing space usage. Please see + * <a href="http://math.stackexchange.com/questions/178690/whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin">here</a>. * * @param n total number of indices * @param m number of samples needed * @param r random number generator * @return list of m random unique indices */ + // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827ae159..eb3ff926372a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.util.Utils /** @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class DiskBlockObjectWriter( val file: File, + serializerManager: SerializerManager, serializerInstance: SerializerInstance, bufferSize: Int, - wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + bs = serializerManager.wrapStream(blockId, mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this @@ -128,16 +128,19 @@ private[spark] class DiskBlockObjectWriter( */ private def closeResources(): Unit = { if (initialized) { - mcs.manualClose() - channel = null - mcs = null - bs = null - fos = null - ts = null - objOut = null - initialized = false - streamOpen = false - hasBeenClosed = true + Utils.tryWithSafeFinally { + mcs.manualClose() + } { + channel = null + mcs = null + bs = null + fos = null + ts = null + objOut = null + initialized = false + streamOpen = false + hasBeenClosed = true + } } } @@ -199,26 +202,29 @@ private[spark] class DiskBlockObjectWriter( def revertPartialWritesAndClose(): File = { // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position.
- try { + Utils.tryWithSafeFinally { if (initialized) { writeMetrics.decBytesWritten(reportedPosition - committedPosition) writeMetrics.decRecordsWritten(numRecordsWritten) streamOpen = false closeResources() } - - val truncateStream = new FileOutputStream(file, true) + } { + var truncateStream: FileOutputStream = null try { + truncateStream = new FileOutputStream(file, true) truncateStream.getChannel.truncate(committedPosition) - file + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) } finally { - truncateStream.close() + if (truncateStream != null) { + truncateStream.close() + truncateStream = null + } } - } catch { - case e: Exception => - logError("Uncaught exception while reverting partial writes to file " + file, e) - file } + file } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4dc2f362329a..7eda6e97a81e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -247,7 +247,7 @@ final class ShuffleBlockFetcherIterator( /** * Fetch the local blocks while we are fetching remote blocks. This is ok because - * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { @@ -304,6 +304,10 @@ final class ShuffleBlockFetcherIterator( * Throws a FetchFailedException if the next block could not be fetched. */ override def next(): (BlockId, InputStream) = { + if (!hasNext) { + throw new NoSuchElementException + } + numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -423,7 +427,7 @@ object ShuffleBlockFetcherIterator { * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. - * @param buf [[ManagedBuffer]] for the content. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0..5efdd23f79a2 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. 
* - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) @@ -231,22 +236,60 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { + // Ewwww... Reflection!!! See the unmap method for justification + private val memoryMappedBufferFileDescriptorField = { + val mappedBufferClass = classOf[java.nio.MappedByteBuffer] + val fdField = mappedBufferClass.getDeclaredField("fd") + fdField.setAccessible(true) + fdField + } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun + * API that will cause errors if one attempts to read from the disposed buffer. However, neither + * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put + * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of + * off-heap memory or huge numbers of open files. 
There's unfortunately no standard API to + * manually dispose of these kinds of buffers. + * + * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() + logTrace(s"Disposing of $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + } + } + + /** + * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will + * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of + * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage + * collection may lead to huge numbers of open files. There's unfortunately no standard API to + * manually unmap memory-mapped buffers. + * + * See also [[dispose]] + */ + def unmap(buffer: ByteBuffer): Unit = { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the + // JDK does not provide a public API to distinguish between direct buffers and memory-mapped + // buffers. As an alternative, we peek beneath the curtains and look for a non-null file + // descriptor in mappedByteBuffer + if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { + logTrace(s"Unmapping $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) } } } + private def cleanDirectBuffer(buffer: DirectBuffer) = { + val cleaner = buffer.cleaner() + if (cleaner != null) { + cleaner.clean() + } + } + /** * Update the given list of RDDInfo with the given list of storage statuses. * This method overwrites the old values stored in the RDDInfo's. 
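The dispose/unmap pair above leans on two JDK 8 internals: `sun.nio.ch.DirectBuffer.cleaner()` frees a buffer's off-heap memory eagerly, and the private `MappedByteBuffer.fd` field tells file-backed mappings apart from plain direct buffers (which, as the comment notes, are also instances of `MappedByteBuffer`, with no public API to distinguish the two). A minimal standalone sketch of the same trick, assuming Java 8 and access to the `sun.nio.ch` package; the object and method names here are illustrative, not Spark's:

```scala
import java.nio.{ByteBuffer, MappedByteBuffer}

import sun.nio.ch.DirectBuffer

object BufferCleanerSketch {
  // MappedByteBuffer.fd is non-null only for buffers backed by a memory-mapped
  // file; ByteBuffer.allocateDirect leaves it null even though the buffer it
  // returns is a MappedByteBuffer at runtime.
  private val fdField = {
    val f = classOf[MappedByteBuffer].getDeclaredField("fd")
    f.setAccessible(true)
    f
  }

  def isMemoryMapped(buffer: ByteBuffer): Boolean =
    buffer.isInstanceOf[MappedByteBuffer] && fdField.get(buffer) != null

  // Eagerly release the memory (and any file mapping) instead of waiting for
  // the GC; reading from the buffer afterwards is undefined behavior.
  def clean(buffer: ByteBuffer): Unit = buffer match {
    case db: DirectBuffer if db.cleaner() != null => db.cleaner().clean()
    case _ => // heap buffer: nothing to release eagerly
  }
}
```

This split is presumably why `replicate()` above now calls `unmap()` rather than `dispose()`: a memory-mapped copy read from disk can be unmapped immediately, while a direct buffer that may still back an in-memory block is left for the GC.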
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 095d32407f34..929a0808bd23 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} -import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} import org.apache.spark.unsafe.Platform import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -334,7 +334,8 @@ private[spark] class MemoryStore( val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { - val ser = serializerManager.getSerializer(classTag).newInstance() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) } @@ -693,7 +694,7 @@ private[storage] class PartiallyUnrolledIterator[T]( } override def next(): T = { - if (unrolled == null) { + if (unrolled == null || !unrolled.hasNext) { rest.next() } else { unrolled.next() diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 35c3c8d00f99..639b8577617f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils */ private[spark] object JettyUtils extends Logging { + val SPARK_CONNECTOR_NAME = "Spark" + val REDIRECT_CONNECTOR_NAME = "HttpsRedirect" + // Base type for a function that returns something based on an HTTP request. Allows for // implicit conversion from many types of functions to jetty Handlers. 
type Responder[T] = HttpServletRequest => T @@ -87,9 +90,9 @@ private[spark] object JettyUtils extends Logging { response.setHeader("X-Frame-Options", xFrameOptionsValue) response.getWriter.print(servletParams.extractFn(result)) } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setStatus(HttpServletResponse.SC_FORBIDDEN) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + response.sendError(HttpServletResponse.SC_FORBIDDEN, "User is not authorized to access this page.") } } catch { @@ -274,17 +277,18 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection addFilters(handlers, conf) val gzipHandlers = handlers.map { h => + h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) + val gzipHandler = new GzipHandler gzipHandler.setHandler(h) gzipHandler } // Bind to the given port, or throw a java.net.BindException if the port is occupied - def connect(currentPort: Int): (Server, Int) = { + def connect(currentPort: Int): ((Server, Option[Int]), Int) = { val pool = new QueuedThreadPool if (serverName.nonEmpty) { pool.setName(serverName) @@ -292,7 +296,9 @@ private[spark] object JettyUtils extends Logging { pool.setDaemon(true) val server = new Server(pool) - val connectors = new ArrayBuffer[ServerConnector] + val connectors = new ArrayBuffer[ServerConnector]() + val collection = new ContextHandlerCollection + // Create a connector on port currentPort to listen for HTTP requests val httpConnector = new ServerConnector( server, @@ -306,26 +312,33 @@ private[spark] object JettyUtils extends Logging { httpConnector.setPort(currentPort) connectors += httpConnector - sslOptions.createJettySslContextFactory().foreach { factory => - // If the new port wraps around, do not try a privileged port. - val securePort = - if (currentPort != 0) { - (currentPort + 400 - 1024) % (65536 - 1024) + 1024 - } else { - 0 - } - val scheme = "https" - // Create a connector on port securePort to listen for HTTPS requests - val connector = new ServerConnector(server, factory) - connector.setPort(securePort) - - connectors += connector - - // redirect the HTTP requests to HTTPS port - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + val httpsConnector = sslOptions.createJettySslContextFactory() match { + case Some(factory) => + // If the new port wraps around, do not try a privileged port. + val securePort = + if (currentPort != 0) { + (currentPort + 400 - 1024) % (65536 - 1024) + 1024 + } else { + 0 + } + val scheme = "https" + // Create a connector on port securePort to listen for HTTPS requests + val connector = new ServerConnector(server, factory) + connector.setPort(securePort) + connector.setName(SPARK_CONNECTOR_NAME) + connectors += connector + + // redirect the HTTP requests to HTTPS port + httpConnector.setName(REDIRECT_CONNECTOR_NAME) + collection.addHandler(createRedirectHttpsHandler(connector, scheme)) + Some(connector) + + case None => + // No SSL, so the HTTP connector becomes the official one where all contexts bind. + httpConnector.setName(SPARK_CONNECTOR_NAME) + None } - gzipHandlers.foreach(collection.addHandler) // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. 
(See SPARK-13776) var minThreads = 1 @@ -337,17 +350,20 @@ // The number of selectors always equals the number of acceptors minThreads += connector.getAcceptors * 2 } - server.setConnectors(connectors.toArray) pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) val errorHandler = new ErrorHandler() errorHandler.setShowStacks(true) errorHandler.setServer(server) server.addBean(errorHandler) + + gzipHandlers.foreach(collection.addHandler) server.setHandler(collection) + + server.setConnectors(connectors.toArray) try { server.start() - (server, httpConnector.getLocalPort) + ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort) } catch { case e: Exception => server.stop() @@ -356,13 +372,18 @@ } } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) - ServerInfo(server, boundPort, collection) + val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf, + serverName) + ServerInfo(server, boundPort, securePort, + server.getHandler().asInstanceOf[ContextHandlerCollection]) } - private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { + private def createRedirectHttpsHandler( + httpsConnector: ServerConnector, + scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") + redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME)) redirectHandler.setHandler(new AbstractHandler { override def handle( target: String, @@ -372,8 +393,8 @@ if (baseRequest.isSecure) { return } - val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, - baseRequest.getRequestURI, baseRequest.getQueryString) + val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, + httpsConnector.getLocalPort, baseRequest.getRequestURI, baseRequest.getQueryString) response.setContentLength(0) response.encodeRedirectURL(httpsURI) response.sendRedirect(httpsURI) @@ -442,7 +463,23 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) { + securePort: Option[Int], + private val rootHandler: ContextHandlerCollection) { + + def addHandler(handler: ContextHandler): Unit = { + handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME)) + rootHandler.addHandler(handler) + if (!handler.isStarted()) { + handler.start() + } + } + + def removeHandler(handler: ContextHandler): Unit = { + rootHandler.removeHandler(handler) + if (handler.isStarted) { + handler.stop() + } + } def stop(): Unit = { server.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 2a7c16b04bf7..79974df2603f 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -175,13 +175,14 @@ private[ui] trait PagedTable[T] { val hiddenFormFields = { if (goButtonFormPath.contains('?')) { - val querystring = goButtonFormPath.split("\\?", 2)(1) + val queryString = goButtonFormPath.split("\\?", 2)(1) + val search = queryString.split("#")(0) Splitter .on('&') .trimResults() .omitEmptyStrings() .withKeyValueSeparator("=") - .split(querystring) + .split(search) .asScala .filterKeys(_ != pageSizeFormField)
.filterKeys(_ != prevPageSizeFormField) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62..40aa7ff8a786 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -34,9 +36,12 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) @@ -170,6 +175,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { @@ -420,8 +426,8 @@ private[spark] object UIUtils extends Logging { * the whole string will be rendered as simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like
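A recurring change throughout this patch is pinning `SimpleDateFormat` to `Locale.US` (JacksonMessageWriter, SimpleDateParam, and the UIUtils `dateFormat` above). The single-argument constructor picks up the JVM's default locale, so the same pattern can format or parse differently from machine to machine. A small illustration of the failure mode this avoids (a standalone sketch, not part of the patch):

```scala
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

object LocaleDateDemo extends App {
  val epoch = new Date(0L)

  // With the Thai "TH" locale variant, SimpleDateFormat uses the Buddhist
  // calendar, so 1970 CE is rendered as the year 2513.
  val thai = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", new Locale("th", "TH", "TH"))
  // Pinning Locale.US keeps the output independent of the default locale.
  val pinned = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US)

  println(thai.format(epoch))   // e.g. 2513/01/01 ... (also shifted by the default time zone)
  println(pinned.format(epoch)) // 1970/01/01 ... in the Gregorian calendar
}
```

If a timestamp produced by such a formatter is later parsed on a machine with a different default locale, the round trip silently yields the wrong date; fixing the locale on both sides rules that out.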