diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index beacc39500aaa..33ed2c98e6e7e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -277,6 +277,7 @@ export("as.DataFrame",
               "read.parquet",
               "read.text",
               "sql",
+              "str",
               "table",
               "tableNames",
               "tables",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 7a7aef27ccb24..a6c6a1d075288 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2300,3 +2300,76 @@ setMethod("with",
             newEnv <- assignNewEnv(data)
             eval(substitute(expr), envir = newEnv, enclos = newEnv)
           })
+
+#' Display the structure of a DataFrame, including column names, column types, as well as a
+#' small sample of rows.
+#' @name str
+#' @title Compactly display the structure of a dataset
+#' @rdname str
+#' @family DataFrame functions
+#' @param object a DataFrame
+#' @examples \dontrun{
+#' # Create a DataFrame from the Iris dataset
+#' irisDF <- createDataFrame(sqlContext, iris)
+#'
+#' # Show the structure of the DataFrame
+#' str(irisDF)
+#' }
+setMethod("str",
+          signature(object = "DataFrame"),
+          function(object) {
+
+            # TODO: These could be made global parameters, though in R it's not the case
+            MAX_CHAR_PER_ROW <- 120
+            MAX_COLS <- 100
+
+            # Get the column names and types of the DataFrame
+            names <- names(object)
+            types <- coltypes(object)
+
+            # Get the first elements of the dataset. Limit number of columns accordingly
+            localDF <- if (ncol(object) > MAX_COLS) {
+              head(object[, c(1:MAX_COLS)])
+            } else {
+              head(object)
+            }
+
+            # The number of observations will not be displayed as computing the
+            # number of rows is a very expensive operation
+            cat(paste0("'", class(object), "': ", length(names), " variables:\n"))
+
+            if (nrow(localDF) > 0) {
+              for (i in 1 : ncol(localDF)) {
+                # Get the first elements for each column
+
+                firstElements <- if (types[i] == "character") {
+                  paste(paste0("\"", localDF[, i], "\""), collapse = " ")
+                } else {
+                  paste(localDF[, i], collapse = " ")
+                }
+
+                # Add the corresponding number of spaces for alignment
+                spaces <- paste(rep(" ", max(nchar(names) - nchar(names[i]))), collapse = "")
+
+                # Get the short type. For 'character', it would be 'chr';
+                # for 'numeric', it's 'num', etc.
+                dataType <- SHORT_TYPES[[types[i]]]
+                if (is.null(dataType)) {
+                  dataType <- substring(types[i], 1, 3)
+                }
+
+                # Concatenate the colnames, coltypes, and first
+                # elements of each column
+                line <- paste0(" $ ", names[i], spaces, ": ",
+                               dataType, " ", firstElements)
+
+                # Chop off extra characters if this is too long
+                cat(substr(line, 1, MAX_CHAR_PER_ROW))
+                cat("\n")
+              }
+
+              if (ncol(localDF) < ncol(object)) {
+                cat(paste0("\nDisplaying first ", ncol(localDF), " columns only."))
+              }
+            }
+          })
\ No newline at end of file
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ba6861709754d..816bbd0d8ca02 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -378,7 +378,6 @@ setGeneric("subtractByKey",
 setGeneric("value", function(bcast) { standardGeneric("value") })
 
-
 
 #################### DataFrame Methods ########################
 
 #' @rdname agg
@@ -389,6 +388,14 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") })
 #' @export
 setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
 
+#' @rdname as.data.frame
+#' @export
+setGeneric("as.data.frame")
+
+#' @rdname attach
+#' @export
+setGeneric("attach")
+
 #' @rdname columns
 #' @export
 setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") })
@@ -525,13 +532,12 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
   standardGeneric("saveAsTable")
 })
 
-#' @rdname withColumn
 #' @export
-setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })
+setGeneric("str")
 
-#' @rdname write.df
+#' @rdname mutate
 #' @export
-setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
+setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })
 
 #' @rdname write.df
 #' @export
@@ -593,6 +599,10 @@ setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") })
 #' @export
 setGeneric("where", function(x, condition) { standardGeneric("where") })
 
+#' @rdname with
+#' @export
+setGeneric("with")
+
 #' @rdname withColumn
 #' @export
 setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") })
@@ -602,6 +612,9 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn
 setGeneric("withColumnRenamed",
            function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })
 
+#' @rdname write.df
+#' @export
+setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
 
 ###################### Column Methods ##########################
 
@@ -1105,7 +1118,6 @@ setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
 #' @export
 setGeneric("year", function(x) { standardGeneric("year") })
 
-
 #' @rdname glm
 #' @export
 setGeneric("glm")
@@ -1117,15 +1129,3 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @rdname rbind
 #' @export
 setGeneric("rbind", signature = "...")
-
-#' @rdname as.data.frame
-#' @export
-setGeneric("as.data.frame")
-
-#' @rdname attach
-#' @export
-setGeneric("attach")
-
-#' @rdname with
-#' @export
-setGeneric("with")
diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R
index 1f06af7e904fe..ad048b1cd1795 100644
--- a/R/pkg/R/types.R
+++ b/R/pkg/R/types.R
@@ -47,10 +47,23 @@ COMPLEX_TYPES <- list(
 # The full list of data types.
 DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES))
 
+SHORT_TYPES <- as.environment(list(
+  "character" = "chr",
+  "logical" = "logi",
+  "POSIXct" = "POSIXct",
+  "integer" = "int",
+  "numeric" = "num",
+  "raw" = "raw",
+  "Date" = "Date",
+  "map" = "map",
+  "array" = "array",
+  "struct" = "struct"
+))
+
 # An environment for mapping R to Scala, names are R types and values are Scala types.
 rToSQLTypes <- as.environment(list(
-  "integer" = "integer", # in R, integer is 32bit
-  "numeric" = "double",  # in R, numeric == double which is 64bit
-  "double" = "double",
+  "integer"   = "integer", # in R, integer is 32bit
+  "numeric"   = "double",  # in R, numeric == double which is 64bit
+  "double"    = "double",
   "character" = "string",
-  "logical" = "boolean"))
+  "logical"   = "boolean"))
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 73f311e2684a2..d6e498dbf752f 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1795,6 +1795,37 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", {
                "Only atomic type is supported for column types")
 })
 
+test_that("Method str()", {
+  # Structure of Iris
+  iris2 <- iris
+  colnames(iris2) <- c("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species")
+  iris2$col <- TRUE
+  irisDF2 <- createDataFrame(sqlContext, iris2)
+
+  out <- capture.output(str(irisDF2))
+  expect_equal(length(out), 7)
+  expect_equal(out[1], "'DataFrame': 6 variables:")
+  expect_equal(out[2], " $ Sepal_Length: num 5.1 4.9 4.7 4.6 5 5.4")
+  expect_equal(out[3], " $ Sepal_Width : num 3.5 3 3.2 3.1 3.6 3.9")
+  expect_equal(out[4], " $ Petal_Length: num 1.4 1.4 1.3 1.5 1.4 1.7")
+  expect_equal(out[5], " $ Petal_Width : num 0.2 0.2 0.2 0.2 0.2 0.4")
+  expect_equal(out[6], paste0(" $ Species     : chr \"setosa\" \"setosa\" \"",
+                              "setosa\" \"setosa\" \"setosa\" \"setosa\""))
+  expect_equal(out[7], " $ col          : logi TRUE TRUE TRUE TRUE TRUE TRUE")
+
+  # A random dataset with many columns. This test is to check str limits
+  # the number of columns. Therefore, it will suffice to check for the
+  # number of returned rows
+  x <- runif(200, 1, 10)
+  df <- data.frame(t(as.matrix(data.frame(x,x,x,x,x,x,x,x,x))))
+  DF <- createDataFrame(sqlContext, df)
+  out <- capture.output(str(DF))
+  expect_equal(length(out), 103)
+
+  # Test utils:::str
+  expect_equal(capture.output(utils:::str(iris)), capture.output(str(iris)))
+})
+
 unlink(parquetPath)
 unlink(jsonPath)
 unlink(jsonPathNa)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index dce1f15a2963c..98a7314457354 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -107,8 +107,10 @@ public UnsafeInMemorySorter(
    * Free the memory used by pointer array.
    */
   public void free() {
-    consumer.freeArray(array);
-    array = null;
+    if (consumer != null) {
+      consumer.freeArray(array);
+      array = null;
+    }
   }
 
   public void reset() {
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
index 8ad4656b4dada..3bdba922328c2 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala
@@ -28,9 +28,13 @@ private[v1] class ExecutorListResource(ui: SparkUI) {
   @GET
   def executorList(): Seq[ExecutorSummary] = {
     val listener = ui.executorsListener
-    val storageStatusList = listener.storageStatusList
-    (0 until storageStatusList.size).map { statusId =>
-      ExecutorsPage.getExecInfo(listener, statusId)
+    listener.synchronized {
+      // The following code should be protected by `listener` to make sure no executors will be
+      // removed before we query their status. See SPARK-12784.
+      val storageStatusList = listener.storageStatusList
+      (0 until storageStatusList.size).map { statusId =>
+        ExecutorsPage.getExecInfo(listener, statusId)
+      }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 81a6f07ec836a..1949c4b3cbf42 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ui
 
+import java.net.URLDecoder
 import java.text.SimpleDateFormat
 import java.util.{Date, Locale}
 
@@ -451,4 +452,19 @@ private[spark] object UIUtils extends Logging {
       <span class="description-input">{desc}</span>
     }
   }
+
+  /**
+   * Decode a URL parameter that may have been encoded by YARN's WebAppProxyServlet.
+   * Due to YARN-2844, WebAppProxyServlet cannot handle URLs which contain encoded characters,
+   * so we need to keep decoding until we get the real URL parameter.
+   */
+  def decodeURLParameter(urlParam: String): String = {
+    var param = urlParam
+    var decodedParam = URLDecoder.decode(param, "UTF-8")
+    while (param != decodedParam) {
+      param = decodedParam
+      decodedParam = URLDecoder.decode(param, "UTF-8")
+    }
+    param
+  }
 }
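Editor's note (not part of the patch): a minimal standalone sketch of why the decoding above has to loop. A value that passed through YARN's WebAppProxyServlet can arrive double-encoded, so a single decode still leaves percent-escapes behind; the expected input/output values here are taken from the UIUtilsSuite test added later in this patch, and `decodeUntilStable` is a hypothetical name, not a Spark API.

```scala
import java.net.URLDecoder

object DecodeUntilStableExample {
  // Same fixed-point idea as UIUtils.decodeURLParameter in this patch.
  def decodeUntilStable(param: String): String = {
    var current = param
    var decoded = URLDecoder.decode(current, "UTF-8")
    while (current != decoded) {
      current = decoded
      decoded = URLDecoder.decode(current, "UTF-8")
    }
    current
  }

  def main(args: Array[String]): Unit = {
    val doubleEncoded = "%253Cdriver%253E"             // "<driver>" encoded twice
    println(URLDecoder.decode(doubleEncoded, "UTF-8")) // "%3Cdriver%3E" -- still encoded
    println(decodeUntilStable(doubleEncoded))          // "<driver>"
  }
}
```

The same round trip applies to pool names: PoolTable encodes the name with `URLEncoder.encode` when building the link, and PoolPage decodes it with `decodeURLParameter` when handling the request.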
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index b0a2cb4aa4d4b..32980544347a8 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.ui.exec
 
-import java.net.URLDecoder
 import javax.servlet.http.HttpServletRequest
 
 import scala.util.Try
@@ -30,18 +29,8 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
   private val sc = parent.sc
 
   def render(request: HttpServletRequest): Seq[Node] = {
-    val executorId = Option(request.getParameter("executorId")).map {
-      executorId =>
-        // Due to YARN-2844, "<driver>" in the url will be encoded to "%25253Cdriver%25253E" when
-        // running in yarn-cluster mode. `request.getParameter("executorId")` will return
-        // "%253Cdriver%253E". Therefore we need to decode it until we get the real id.
-        var id = executorId
-        var decodedId = URLDecoder.decode(id, "UTF-8")
-        while (id != decodedId) {
-          id = decodedId
-          decodedId = URLDecoder.decode(id, "UTF-8")
-        }
-        id
+    val executorId = Option(request.getParameter("executorId")).map { executorId =>
+      UIUtils.decodeURLParameter(executorId)
     }.getOrElse {
       throw new IllegalArgumentException(s"Missing executorId parameter")
     }
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index 1a29b0f412603..7072a152d6b69 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -52,12 +52,19 @@ private[ui] class ExecutorsPage(
   private val listener = parent.listener
 
   def render(request: HttpServletRequest): Seq[Node] = {
-    val storageStatusList = listener.storageStatusList
+    val (storageStatusList, execInfo) = listener.synchronized {
+      // The following code should be protected by `listener` to make sure no executors will be
+      // removed before we query their status. See SPARK-12784.
+      val _storageStatusList = listener.storageStatusList
+      val _execInfo = {
+        for (statusId <- 0 until _storageStatusList.size)
+          yield ExecutorsPage.getExecInfo(listener, statusId)
+      }
+      (_storageStatusList, _execInfo)
+    }
     val maxMem = storageStatusList.map(_.maxMem).sum
     val memUsed = storageStatusList.map(_.memUsed).sum
     val diskUsed = storageStatusList.map(_.diskUsed).sum
-    val execInfo = for (statusId <- 0 until storageStatusList.size) yield
-      ExecutorsPage.getExecInfo(listener, statusId)
     val execInfoSorted = execInfo.sortBy(_.id)
     val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty
 
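Editor's note (not part of the patch): a minimal sketch of the race the SPARK-12784 comments describe. `DemoListener` is a simplified stand-in, not Spark's ExecutorsListener; the point is that the size and the elements must be read under the same lock that guards removals.

```scala
import scala.collection.mutable.ArrayBuffer

object ListenerSnapshotSketch {
  // Stand-in for a UI listener whose state is mutated from the event bus.
  class DemoListener {
    private val statuses = ArrayBuffer("exec-1", "exec-2", "exec-3")
    def storageStatusList: Seq[String] = statuses
    def removeExecutor(id: String): Unit = synchronized { statuses -= id }
  }

  // Unsafe read: the size and each element come from separate calls, so a
  // concurrent removeExecutor() in between can leave us indexing past the end.
  def unsafeRead(listener: DemoListener): Seq[String] =
    (0 until listener.storageStatusList.size).map(i => listener.storageStatusList(i))

  // The pattern this patch applies in ExecutorsPage / ExecutorListResource:
  // snapshot everything under the listener's lock so removals cannot interleave.
  def safeRead(listener: DemoListener): Seq[String] = listener.synchronized {
    val statuses = listener.storageStatusList
    (0 until statuses.size).map(i => statuses(i))
  }
}
```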
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index f3e0b38523f32..778272a6da1e9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -31,8 +31,11 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
 
   def render(request: HttpServletRequest): Seq[Node] = {
     listener.synchronized {
-      val poolName = request.getParameter("poolname")
-      require(poolName != null && poolName.nonEmpty, "Missing poolname parameter")
+      val poolName = Option(request.getParameter("poolname")).map { poolname =>
+        UIUtils.decodeURLParameter(poolname)
+      }.getOrElse {
+        throw new IllegalArgumentException(s"Missing poolname parameter")
+      }
 
       val poolToActiveStages = listener.poolToActiveStages
       val activeStages = poolToActiveStages.get(poolName) match {
@@ -44,7 +47,9 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
           killEnabled = parent.killEnabled)
 
       // For now, pool information is only accessible in live UIs
-      val pools = sc.map(_.getPoolForName(poolName).get).toSeq
+      val pools = sc.map(_.getPoolForName(poolName).getOrElse {
+        throw new IllegalArgumentException(s"Unknown poolname: $poolName")
+      }).toSeq
       val poolTable = new PoolTable(pools, parent)
 
       val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 9ba2af54dacf4..ea02968733cac 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ui.jobs
 
+import java.net.URLEncoder
+
 import scala.collection.mutable.HashMap
 import scala.xml.Node
 
@@ -59,7 +61,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) {
       case None => 0
     }
     val href = "%s/stages/pool?poolname=%s"
-      .format(UIUtils.prependBaseUri(parent.basePath), p.name)
+      .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8"))
     <tr>
       <td>
         <a href={href}>{p.name}</a>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 1b34ba9f03c44..5183c80ab4526 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -99,7 +99,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
     val parameterTaskPageSize = request.getParameter("task.pageSize")
 
     val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
-    val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index")
+    val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
+      UIUtils.decodeURLParameter(sortColumn)
+    }.getOrElse("Index")
     val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false)
     val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100)
 
diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
index 14b6ba4af489a..86bbaa20f6cf2 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.logging
 
-import java.io.{File, FileOutputStream, InputStream}
+import java.io.{File, FileOutputStream, InputStream, IOException}
 
 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.util.{IntParam, Utils}
@@ -29,7 +29,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
   extends Logging {
   @volatile private var outputStream: FileOutputStream = null
   @volatile private var markedForStop = false     // has the appender been asked to stopped
-  @volatile private var stopped = false           // has the appender stopped
 
   // Thread that reads the input stream and writes to file
   private val writingThread = new Thread("File appending thread for " + file) {
@@ -47,11 +46,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
    * or because of any error in appending
    */
  def awaitTermination() {
-    synchronized {
-      if (!stopped) {
-        wait()
-      }
-    }
+    writingThread.join()
  }
 
  /** Stop the appender */
@@ -63,24 +58,28 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi
   protected def appendStreamToFile() {
     try {
       logDebug("Started appending thread")
-      openFile()
-      val buf = new Array[Byte](bufferSize)
-      var n = 0
-      while (!markedForStop && n != -1) {
-        n = inputStream.read(buf)
-        if (n != -1) {
-          appendToFile(buf, n)
+      Utils.tryWithSafeFinally {
+        openFile()
+        val buf = new Array[Byte](bufferSize)
+        var n = 0
+        while (!markedForStop && n != -1) {
+          try {
+            n = inputStream.read(buf)
+          } catch {
+            // An InputStream can throw IOException during read if the stream is closed
+            // asynchronously, so once the appender has been flagged to stop, these will be ignored
+            case _: IOException if markedForStop => // do nothing and proceed to stop appending
+          }
+          if (n > 0) {
+            appendToFile(buf, n)
+          }
         }
+      } {
+        closeFile()
       }
     } catch {
       case e: Exception =>
         logError(s"Error writing stream to file $file", e)
-    } finally {
-      closeFile()
-      synchronized {
-        stopped = true
-        notifyAll()
-      }
     }
   }
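Editor's note (not part of the patch): a standalone sketch of the same read-loop pattern the change above adds, where a read failure on a stream that is closed asynchronously only counts as an error if the appender has not been asked to stop. The names (`StoppableCopySketch`, `copyUntilStopped`, the callback parameters) are illustrative, not Spark APIs.

```scala
import java.io.{IOException, InputStream}

object StoppableCopySketch {
  /** Copy `in` to `write` until EOF or until `isStopped` becomes true. */
  def copyUntilStopped(
      in: InputStream,
      write: (Array[Byte], Int) => Unit,
      isStopped: () => Boolean,
      bufferSize: Int = 8192): Unit = {
    val buf = new Array[Byte](bufferSize)
    var n = 0
    while (!isStopped() && n != -1) {
      try {
        n = in.read(buf)
      } catch {
        // A stream closed from another thread may fail the in-flight read; once we
        // have been asked to stop, treat that as a normal end of input.
        case _: IOException if isStopped() => n = -1
      }
      if (n > 0) {
        write(buf, n)
      }
    }
  }
}
```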
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index dd8d5ec27f87e..bc8a5d494dbd3 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -67,6 +67,20 @@ class UIUtilsSuite extends SparkFunSuite {
       s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated")
   }
 
+  test("decodeURLParameter (SPARK-12708: Sorting task error in Stages Page in yarn mode)") {
+    val encoded1 = "%252F"
+    val decoded1 = "/"
+    val encoded2 = "%253Cdriver%253E"
+    val decoded2 = "<driver>"
+
+    assert(decoded1 === decodeURLParameter(encoded1))
+    assert(decoded2 === decodeURLParameter(encoded2))
+
+    // verify that decoding an already-decoded URL has no effect
+    assert(decoded1 === decodeURLParameter(decoded1))
+    assert(decoded2 === decodeURLParameter(decoded2))
+  }
+
   private def verify(
       desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = {
     val generated = makeDescription(desc, baseUrl)
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 2b76ae1f8a24b..5a14fc7b1d38a 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -18,14 +18,18 @@
 package org.apache.spark.util
 
 import java.io._
+import java.util.concurrent.CountDownLatch
 
 import scala.collection.mutable.HashSet
 import scala.reflect._
 
-import org.scalatest.BeforeAndAfter
-
 import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
+import org.apache.log4j.{Appender, Level, Logger}
+import org.apache.log4j.spi.LoggingEvent
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.{atLeast, mock, verify}
+import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
 import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender}
@@ -189,6 +193,67 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
     testAppenderSelection[FileAppender, Any](rollingStrategy("xyz"))
   }
 
+  test("file appender async close stream abruptly") {
+    // Test FileAppender reaction to closing InputStream using a mock logging appender
+    val mockAppender = mock(classOf[Appender])
+    val loggingEventCaptor = new ArgumentCaptor[LoggingEvent]
+
+    // Make sure only errors are logged
+    val logger = Logger.getRootLogger
+    logger.setLevel(Level.ERROR)
+    logger.addAppender(mockAppender)
+
+    val testOutputStream = new PipedOutputStream()
+    val testInputStream = new PipedInputStream(testOutputStream)
+
+    // Closing the stream before the appender tries to read will cause an IOException
+    testInputStream.close()
+    testOutputStream.close()
+    val appender = FileAppender(testInputStream, testFile, new SparkConf)
+
+    appender.awaitTermination()
+
+    // If InputStream was closed without first stopping the appender, an exception will be logged
+    verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture)
+    val loggingEvent = loggingEventCaptor.getValue
+    assert(loggingEvent.getThrowableInformation !== null)
+    assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+  }
+
+  test("file appender async close stream gracefully") {
+    // Test FileAppender reaction to closing InputStream using a mock logging appender
+    val mockAppender = mock(classOf[Appender])
+    val loggingEventCaptor = new ArgumentCaptor[LoggingEvent]
+
+    // Make sure only errors are logged
+    val logger = Logger.getRootLogger
+    logger.setLevel(Level.ERROR)
+    logger.addAppender(mockAppender)
+
+    val testOutputStream = new PipedOutputStream()
+    val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream
+
+    // Closing the stream before the appender tries to read will cause an IOException
+    testInputStream.close()
+    testOutputStream.close()
+    val appender = FileAppender(testInputStream, testFile, new SparkConf)
+
+    // Stop the appender before an IOException is thrown during read
+    testInputStream.latchReadStarted.await()
+    appender.stop()
+    testInputStream.latchReadProceed.countDown()
+
+    appender.awaitTermination()
+
+    // Make sure no IOException errors have been logged as a result of the appender closing gracefully
+    verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture)
+    import scala.collection.JavaConverters._
+    loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent =>
+      assert(loggingEvent.getThrowableInformation === null
+        || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+    }
+  }
+
   /**
    * Run the rolling file appender with data and see whether all the data was written correctly
    * across rolled over files.
@@ -229,4 +294,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
         file.getName.startsWith(testFile.getName)
       }.foreach { _.delete() }
   }
+
+  /** Used to synchronize when read is called on a stream */
+  private trait LatchedInputStream extends PipedInputStream {
+    val latchReadStarted = new CountDownLatch(1)
+    val latchReadProceed = new CountDownLatch(1)
+    abstract override def read(): Int = {
+      latchReadStarted.countDown()
+      latchReadProceed.await()
+      super.read()
+    }
+  }
 }
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 44a316a07dfef..6d9659686f96c 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -428,7 +428,7 @@ This example follows the simple text document `Pipeline` illustrated in the figu
 
 {% highlight scala %}
-import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.spark.mllib.linalg.Vector
 
@@ -466,7 +466,7 @@ model.save("/tmp/spark-logistic-regression-model")
 pipeline.save("/tmp/unfit-lr-model")
 
 // and load it back in during production
-val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model")
+val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")
 
 // Prepare test documents, which are unlabeled (id, text) tuples.
 val test = sqlContext.createDataFrame(Seq(
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 3193e17853483..ed720f1039f94 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -202,7 +202,7 @@ where each application gets more or fewer machines as it ramps up and down, but
 additional overhead in launching each task. This mode may be inappropriate for low-latency
 requirements like interactive queries or serving web requests.
 
-To run in coarse-grained mode, set the `spark.mesos.coarse` property to false in your
+To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your
 [SparkConf](configuration.html#spark-properties):
 
 {% highlight scala %}
@@ -266,13 +266,11 @@ See the [configuration page](configuration.html) for information on Spark config
<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
   <td><code>spark.mesos.coarse</code></td>
-  <td>false</td>
+  <td>true</td>
   <td>
-    If set to <code>true</code>, runs over Mesos clusters in
-    "coarse-grained" sharing mode,
-    where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per
-    Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use
-    for the whole duration of the Spark job.
+    If set to <code>true</code>, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine.
+    If set to <code>false</code>, runs over Mesos clusters in "fine-grained" sharing mode, where one Mesos task is created per Spark task.
+    See 'Mesos Run Modes' for more detail.
   </td>
</tr>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 1dbedaaca3d67..30a184901925c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -152,7 +152,7 @@ class Word2Vec extends Serializable with Logging {
   /** context words from [-window, window] */
   private var window = 5
 
-  private var trainWordsCount = 0
+  private var trainWordsCount = 0L
   private var vocabSize = 0
   @transient private var vocab: Array[VocabWord] = null
   @transient private var vocabHash = mutable.HashMap.empty[String, Int]
@@ -160,13 +160,13 @@ class Word2Vec extends Serializable with Logging {
   private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
       .map(x => VocabWord(
         x._1,
         x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
-      .filter(_.cn >= minCount)
       .collect()
       .sortWith((a, b) => a.cn > b.cn)
 
@@ -180,7 +180,7 @@ class Word2Vec extends Serializable with Logging {
       trainWordsCount += vocab(a).cn
       a += 1
     }
-    logInfo("trainWordsCount = " + trainWordsCount)
+    logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
   }
 
   private def createExpTable(): Array[Float] = {
@@ -330,7 +330,7 @@ class Word2Vec extends Serializable with Logging {
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
         val syn0Modify = new Array[Int](vocabSize)
         val syn1Modify = new Array[Int](vocabSize)
-        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+        val model = iter.foldLeft((syn0Global, syn1Global, 0L, 0L)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 23c8d7c7c8075..1c583a45153ee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -109,7 +109,9 @@ private[stat] object ChiSqTest extends Logging {
         }
         i += 1
         distinctLabels += label
-        features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) =>
+        val brzFeatures = features.toBreeze
+        (startCol until endCol).map { col =>
+          val feature = brzFeatures(col)
           allDistinctFeatures(col) += feature
           (col, feature, label)
         }
@@ -122,7 +124,7 @@ private[stat] object ChiSqTest extends Logging {
         pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap
       }
       val numLabels = labels.size
-      pairCounts.keys.groupBy(_._1).map { case (col, keys) =>
+      pairCounts.keys.groupBy(_._1).foreach { case (col, keys) =>
        val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap
        val numRows = features.size
        val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels))
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 99331297c19f0..26cafca8b8381 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -76,4 +76,6 @@
 # which allows us to execute the user's PYTHONSTARTUP file:
 _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
 if _pythonstartup and os.path.isfile(_pythonstartup):
-    execfile(_pythonstartup)
+    with open(_pythonstartup) as f:
+        code = compile(f.read(), _pythonstartup, 'exec')
+        exec(code)
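Editor's note (not part of the patch): the Word2Vec change above widens `trainWordsCount` (and the per-partition word counters) from Int to Long. A minimal sketch of why that matters; the object and value names here are illustrative only.

```scala
object WordCountOverflowExample {
  def main(args: Array[String]): Unit = {
    // Summing per-word counts for a large corpus can exceed Int.MaxValue.
    val counts = Seq(1500000000, 1500000000)  // two very frequent vocabulary entries
    val intSum = counts.foldLeft(0)(_ + _)    // wraps around to -1294967296
    val longSum = counts.foldLeft(0L)(_ + _)  // 3000000000, as expected
    println(s"Int sum = $intSum, Long sum = $longSum")
  }
}
```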
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index f0697613cff3b..f7596300e89f1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -28,11 +28,13 @@ import scala.reflect.ClassTag
 
 import com.esotericsoftware.kryo.Kryo
 import com.esotericsoftware.kryo.io.{Input, Output}
+import com.google.common.base.Objects
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
 import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
 import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
 import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
@@ -47,6 +49,7 @@ private[hive] object HiveShim {
   // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
   val UNLIMITED_DECIMAL_PRECISION = 38
   val UNLIMITED_DECIMAL_SCALE = 18
+  val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro"
 
   /*
    * This function in hive-0.13 become private, but we have to do this to walkaround hive bug
@@ -125,6 +128,26 @@ private[hive] object HiveShim {
     // for Serialization
     def this() = this(null)
 
+    override def hashCode(): Int = {
+      if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+        Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody())
+      } else {
+        functionClassName.hashCode()
+      }
+    }
+
+    override def equals(other: Any): Boolean = other match {
+      case a: HiveFunctionWrapper if functionClassName == a.functionClassName =>
+        // In case of udf macro, check to make sure they point to the same underlying UDF
+        if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) {
+          a.instance.asInstanceOf[GenericUDFMacro].getBody() ==
+            instance.asInstanceOf[GenericUDFMacro].getBody()
+        } else {
+          true
+        }
+      case _ => false
+    }
+
     @transient
     def deserializeObjectByKryo[T: ClassTag](
         kryo: Kryo,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 9deb1a6db15ad..f8b0f01f9a873 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -350,6 +350,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     sqlContext.dropTempTable("testUDF")
   }
 
+  test("Hive UDF in group by") {
+    Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1")
+    val count = sql("select date(cast(test_date as timestamp))" +
+      " from tab1 group by date(cast(test_date as timestamp))").count()
+    assert(count == 1)
+  }
+
   test("SPARK-11522 select input_file_name from non-parquet table"){
 
     withTempDir { tempDir =>